Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a306e0d
Commit
1a306e0d
authored
May 18, 2023
by
danyao12
Browse files
add OutputDataType&Deterministic for pt1q2
parent
35b59ac6
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
367 additions
and
274 deletions
+367
-274
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
..._softmax_gemm/batched_multihead_attention_backward_v4.cpp
+160
-131
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
+163
-110
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
+44
-33
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
View file @
1a306e0d
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
View file @
1a306e0d
...
...
@@ -28,7 +28,8 @@ namespace tensor_operation {
namespace
device
{
template
<
typename
GridwiseGemm
,
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
AElementwiseOperation
,
...
...
@@ -46,22 +47,23 @@ template <typename GridwiseGemm,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
Deterministic
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
(
const
DataType
*
__restrict__
p_a_grid
,
const
DataType
*
__restrict__
p_b_grid
,
const
Input
DataType
*
__restrict__
p_a_grid
,
const
Input
DataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_b1_grid
,
const
DataType
*
__restrict__
p_c_grid
,
const
Input
DataType
*
__restrict__
p_b1_grid
,
const
Input
DataType
*
__restrict__
p_c_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
...
@@ -78,6 +80,7 @@ __global__ void
const
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
...
...
@@ -109,33 +112,72 @@ __global__ void
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
);
if
constexpr
(
Deterministic
)
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
i
);
}
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
0
);
}
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
...
...
@@ -168,7 +210,8 @@ template <index_t NumDimG,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
LSEDataType
,
...
...
@@ -221,6 +264,7 @@ template <index_t NumDimG,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
...
...
@@ -587,7 +631,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
DataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
ZDataType
,
GemmDataType
,
GemmAccDataType
,
CShuffleDataType
,
...
...
@@ -643,22 +689,23 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
,
Deterministic
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
DataType
*
p_a_grid
,
const
DataType
*
p_b_grid
,
const
Input
DataType
*
p_a_grid
,
const
Input
DataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
const
DataType
*
p_b1_grid
,
const
DataType
*
p_c_grid
,
// for dS
const
Input
DataType
*
p_b1_grid
,
const
Input
DataType
*
p_c_grid
,
// for dS
const
LSEDataType
*
p_lse_grid
,
const
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
Output
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
...
@@ -802,16 +849,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
// pointers
const
DataType
*
p_a_grid_
;
const
DataType
*
p_b_grid_
;
const
Input
DataType
*
p_a_grid_
;
const
Input
DataType
*
p_b_grid_
;
ZDataType
*
p_z_grid_
;
const
DataType
*
p_b1_grid_
;
const
DataType
*
p_c_grid_
;
const
Input
DataType
*
p_b1_grid_
;
const
Input
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
const
Input
DataType
*
p_ygrad_grid_
;
Output
DataType
*
p_qgrad_grid_
;
Output
DataType
*
p_kgrad_grid_
;
Output
DataType
*
p_vgrad_grid_
;
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
...
@@ -876,14 +923,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
)
*
arg
.
batch_count_
;
(
Deterministic
?
1
:
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
))
*
arg
.
batch_count_
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
<
GridwiseGemm
,
DataType
,
InputDataType
,
OutputDataType
,
ZDataType
,
LSEDataType
,
AElementwiseOperation
,
...
...
@@ -901,42 +951,45 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_ygrad_grid_
,
arg
.
p_qgrad_grid_
,
arg
.
p_kgrad_grid_
,
arg
.
p_vgrad_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
has_main_k_block_loop_
,
Deterministic
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_ygrad_grid_
,
arg
.
p_qgrad_grid_
,
arg
.
p_kgrad_grid_
,
arg
.
p_vgrad_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
),
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
@@ -1041,16 +1094,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
static
auto
MakeArgument
(
const
DataType
*
p_a
,
const
DataType
*
p_b
,
const
Input
DataType
*
p_a
,
const
Input
DataType
*
p_b
,
ZDataType
*
p_z
,
const
DataType
*
p_b1
,
const
DataType
*
p_c
,
const
Input
DataType
*
p_b1
,
const
Input
DataType
*
p_c
,
const
LSEDataType
*
p_lse
,
const
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
Output
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
...
@@ -1156,16 +1209,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
static_cast
<
const
DataType
*>
(
p_b
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
Input
DataType
*>
(
p_a
),
static_cast
<
const
Input
DataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
DataType
*>
(
p_b1
),
static_cast
<
const
DataType
*>
(
p_c
),
static_cast
<
const
Input
DataType
*>
(
p_b1
),
static_cast
<
const
Input
DataType
*>
(
p_c
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
DataType
*>
(
p_ygrad_grid
),
static_cast
<
DataType
*>
(
p_qgrad_grid
),
static_cast
<
DataType
*>
(
p_kgrad_grid
),
static_cast
<
DataType
*>
(
p_vgrad_grid
),
static_cast
<
const
Input
DataType
*>
(
p_ygrad_grid
),
static_cast
<
Output
DataType
*>
(
p_qgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_kgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
View file @
1a306e0d
...
...
@@ -20,7 +20,9 @@
namespace
ck
{
template
<
typename
DataType
,
template
<
typename
InputDataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
GemmDataType
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
...
...
@@ -77,6 +79,7 @@ template <typename DataType,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
Deterministic
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
...
...
@@ -424,7 +427,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GridDesc_K0_M_K1
,
decltype
(
q_block_desc_k0_m_k1
),
...
...
@@ -449,7 +452,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
decltype
(
k_block_desc_k0_n_k1
),
...
...
@@ -474,7 +477,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
decltype
(
v_block_desc_k0_n_k1
),
...
...
@@ -499,7 +502,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GridDesc_K0_M_K1
,
decltype
(
ygrad_block_desc_k0_m_k1
),
...
...
@@ -1080,7 +1083,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
DataType
,
Output
DataType
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4
,
ElementwiseOp
,
// CElementwiseOperation
...
...
@@ -1096,7 +1099,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
DataType
);
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
Input
DataType
);
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
...
...
@@ -1213,16 +1216,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
YGradGridDesc_O0_M_O1
>
__device__
static
void
Run
(
const
DataType
*
__restrict__
p_q_grid
,
const
DataType
*
__restrict__
p_k_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_v_grid
,
const
DataType
*
__restrict__
p_y_grid
,
__device__
static
void
Run
(
const
Input
DataType
*
__restrict__
p_q_grid
,
const
Input
DataType
*
__restrict__
p_k_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
Input
DataType
*
__restrict__
p_v_grid
,
const
Input
DataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
...
...
@@ -1241,7 +1244,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
float
p_drop
,
ck
::
philox
&
ph
)
ck
::
philox
&
ph
,
const
index_t
block_idx_n
)
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
...
...
@@ -1273,9 +1277,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
block_work_idx_n
=
Deterministic
?
block_idx_n
:
block_work_idx
[
I0
];
// HACK: this force n_block_data_idx_on_grid into SGPR
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
NPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
_n
*
NPerBlock
);
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
...
...
@@ -1605,7 +1611,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -1625,15 +1631,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
]
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
make_multi_index
(
block_work_idx
_n
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
...
...
@@ -1678,7 +1684,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// performs for y
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
Input
DataType
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
...
...
@@ -1743,6 +1749,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
index_t
num_gemm0_m_block_outer_loop
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
)
/
MPerBlock
;
constexpr
index_t
num_gemm1_k_block_inner_loop
=
MPerBlock
/
Gemm1KPerBlock
;
if
constexpr
(
Deterministic
)
{
block_sync_lds
();
}
// Initialize dK&dV
kgrad_thread_buf
.
Clear
();
vgrad_thread_buf
.
Clear
();
...
...
@@ -2286,7 +2297,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
DataType
,
// typename DstData,
Output
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
vgrad_grid_desc_nblock_nperblock_oblock_operblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
...
@@ -2297,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
vgrad_grid_desc_nblock_nperblock_oblock_operblock
,
make_multi_index
(
block_work_idx
[
I0
]
,
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
_n
,
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// shuffle: threadwise copy C from VGPR to LDS
...
...
@@ -2344,7 +2355,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
DataType
,
// typename DstData,
Output
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
kgrad_grid_desc_nblock_nperblock_oblock_operblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
...
@@ -2355,7 +2366,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
kgrad_grid_desc_nblock_nperblock_oblock_operblock
,
make_multi_index
(
block_work_idx
[
I0
]
,
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
_n
,
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment