Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4033f5df
Unverified
Commit
4033f5df
authored
Oct 26, 2023
by
Dan Yao
Committed by
GitHub
Oct 26, 2023
Browse files
Merge pull request #949 from ROCmSoftwarePlatform/mha-train-mqagqa
Grouped Query Attention/Multi Query Attention
parents
18be6bc9
957d5dee
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
111 additions
and
32 deletions
+111
-32
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+68
-10
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+19
-6
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+6
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+6
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+6
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+6
-4
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
4033f5df
...
...
@@ -44,6 +44,7 @@ __global__ void
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
h_ratio
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
...
@@ -82,19 +83,26 @@ __global__ void
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
(
Deterministic
?
1
:
num_blocks_per_batch
));
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
bgrad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBGradBasePtr
(
g_idx
)));
const
long_index_t
b1grad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1GradBasePtr
(
g_idx
)));
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
...
...
@@ -128,9 +136,9 @@ __global__ void
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -139,9 +147,11 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
bgrad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1grad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
...
...
@@ -167,9 +177,9 @@ __global__ void
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -178,9 +188,11 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
bgrad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1grad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
...
...
@@ -196,6 +208,7 @@ __global__ void
#else
ignore
=
group_kernel_args
;
ignore
=
group_count
;
ignore
=
h_ratio
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
...
...
@@ -313,6 +326,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
bgrad_gs_ns_ks_lengths
;
std
::
vector
<
index_t
>
bgrad_gs_ns_ks_strides
;
std
::
vector
<
index_t
>
b1grad_gs_gemm1ns_gemm1ks_lengths
;
std
::
vector
<
index_t
>
b1grad_gs_gemm1ns_gemm1ks_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
...
...
@@ -557,7 +576,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
return
Transform
::
MakeC0GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
...
...
@@ -566,7 +584,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
return
Transform
::
MakeC0GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
...
...
@@ -613,6 +630,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
BGridDesc_G_N_K
&
bgrad_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1grad_grid_desc_g_n_k
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
...
...
@@ -620,6 +639,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
bgrad_grid_desc_g_n_k_
(
bgrad_grid_desc_g_n_k
),
b1grad_grid_desc_g_n_k_
(
b1grad_grid_desc_g_n_k
),
BatchStrideLSE_
(
BatchStrideLSE
)
{
}
...
...
@@ -659,6 +680,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
}
__host__
__device__
constexpr
long_index_t
GetBGradBasePtr
(
index_t
g_idx
)
const
{
return
bgrad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1GradBasePtr
(
index_t
g_idx
)
const
{
return
b1grad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
...
...
@@ -666,6 +697,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
index_t
BatchStrideLSE_
;
};
...
...
@@ -765,9 +798,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...
...
@@ -802,6 +837,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
// for gridwise gemm check
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
index_t
batch_count_
;
...
...
@@ -870,6 +906,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t
z_random_matrix_offset
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
...
@@ -897,6 +936,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
bgrad_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
bgrad_gs_ns_ks_lengths
,
problem_desc
.
bgrad_gs_ns_ks_strides
);
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
...
...
@@ -919,6 +960,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
b1grad_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_strides
);
const
auto
y_grid_desc_m_o
=
Transform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
...
...
@@ -942,6 +986,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
const
auto
bgrad_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
problem_desc
.
bgrad_gs_ns_ks_lengths
,
problem_desc
.
bgrad_gs_ns_ks_strides
);
const
auto
b1grad_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_strides
);
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
...
...
@@ -975,6 +1024,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
bgrad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]));
// C0 mask
...
...
@@ -1002,9 +1053,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_vgrad_grid
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
...
...
@@ -1042,6 +1095,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
b_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
batch_count
,
d0_n_length_stride
});
...
...
@@ -1068,6 +1122,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t
grid_size_
;
index_t
group_count_
;
index_t
h_ratio_
;
std
::
vector
<
GroupKernelArg
>
group_kernel_args_
;
std
::
vector
<
GroupDeviceArg
>
group_device_args_
;
...
...
@@ -1126,6 +1181,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
arg
.
group_count_
,
arg
.
h_ratio_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
...
...
@@ -1194,13 +1250,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
device_arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
c_g
/
b_g
==
arg
.
h_ratio_
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
4033f5df
...
...
@@ -43,6 +43,7 @@ __global__ void
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
h_ratio
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
...
@@ -87,13 +88,14 @@ __global__ void
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
...
...
@@ -147,6 +149,7 @@ __global__ void
#else
ignore
=
group_kernel_args
;
ignore
=
group_count
;
ignore
=
h_ratio
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
...
...
@@ -367,7 +370,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_strides
)
{
return
Transform
::
MakeC0GridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
);
}
...
...
@@ -376,7 +378,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_strides
)
{
return
Transform
::
MakeC0GridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
);
}
...
...
@@ -606,6 +607,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// for gridwise gemm check
CGridDesc_M_N
c_grid_desc_m_n_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
...
...
@@ -654,6 +657,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t
z_random_matrix_offset
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b0_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
std
::
size_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
const
auto
p_a_grid
=
static_cast
<
const
ADataType
*>
(
p_a_vec
[
i
]);
...
...
@@ -805,6 +811,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_m_n
,
b_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
d0_n_length_stride
});
}
...
...
@@ -830,6 +838,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
h_ratio_
;
float
p_dropout_
;
uint8_t
p_dropout_in_uint8_t_
;
unsigned
long
long
seed_
;
...
...
@@ -918,6 +927,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
arg
.
group_count_
,
arg
.
h_ratio_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
...
...
@@ -1040,11 +1050,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
device_arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
device_arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
device_arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
c_g
/
b_g
==
arg
.
h_ratio_
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
4033f5df
...
...
@@ -1443,10 +1443,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
kgrad_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
vgrad_grid_desc_o0_n_o1
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
YGradGridDesc_O0_M_O1
&
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
&
block_2_ctile_map
,
...
...
@@ -1477,11 +1479,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_o0_m_o1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
p_vgrad_grid
,
v
grad
_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_kgrad_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_kgrad_grid
,
k
grad
_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
// divide block work by [N, K]
const
auto
block_work_idx
=
...
...
@@ -1631,7 +1633,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// dV: transform input and output tensor descriptors
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_o0_n_o1
);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v
grad
_grid_desc_o0_n_o1
);
// dK: A matrix blockwise copy
auto
kgrad_gemm_tile_sgrad_blockwise_copy
=
...
...
@@ -1660,7 +1662,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// dK: transform input and output tensor descriptors
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k_grid_desc_k0_n_k1
);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k
grad
_grid_desc_k0_n_k1
);
//
// set up dQ Gemm (type 3 crr)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
4033f5df
...
...
@@ -1535,10 +1535,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
kgrad_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
vgrad_grid_desc_o0_n_o1
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
const
Block2CTileMap
&
block_2_ctile_map
,
...
...
@@ -1569,11 +1571,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
p_vgrad_grid
,
v
grad
_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_kgrad_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_kgrad_grid
,
k
grad
_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
// divide block work by [N, K]
const
auto
block_work_idx
=
...
...
@@ -1746,7 +1748,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dV: transform input and output tensor descriptors
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_o0_n_o1
);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v
grad
_grid_desc_o0_n_o1
);
// dK: transform input and output tensor descriptors
const
auto
q_grid_desc_m0_k_m1
=
...
...
@@ -1779,7 +1781,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dK: transform input and output tensor descriptors
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k_grid_desc_k0_n_k1
);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k
grad
_grid_desc_k0_n_k1
);
//
// set up dQ Gemm (type 3 crr)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
4033f5df
...
...
@@ -1524,10 +1524,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
kgrad_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
vgrad_grid_desc_o0_n_o1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
...
...
@@ -1560,11 +1562,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_o0_m_o1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
p_vgrad_grid
,
v
grad
_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_kgrad_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_kgrad_grid
,
k
grad
_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
// divide block work by [N, K]
const
auto
block_work_idx
=
...
...
@@ -1714,7 +1716,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// dV: transform input and output tensor descriptors
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_o0_n_o1
);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v
grad
_grid_desc_o0_n_o1
);
// dK: A matrix blockwise copy
auto
kgrad_gemm_tile_sgrad_blockwise_copy
=
...
...
@@ -1743,7 +1745,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// dK: transform input and output tensor descriptors
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k_grid_desc_k0_n_k1
);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k
grad
_grid_desc_k0_n_k1
);
//
// set up dQ Gemm (type 3 crr)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
4033f5df
...
...
@@ -1600,10 +1600,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
kgrad_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
vgrad_grid_desc_o0_n_o1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
...
...
@@ -1636,11 +1638,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
p_vgrad_grid
,
v
grad
_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_kgrad_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_kgrad_grid
,
k
grad
_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
// divide block work by [N, K]
const
auto
block_work_idx
=
...
...
@@ -1813,7 +1815,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV: transform input and output tensor descriptors
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_o0_n_o1
);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v
grad
_grid_desc_o0_n_o1
);
// dK: transform input and output tensor descriptors
const
auto
q_grid_desc_m0_k_m1
=
...
...
@@ -1846,7 +1848,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dK: transform input and output tensor descriptors
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k_grid_desc_k0_n_k1
);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k
grad
_grid_desc_k0_n_k1
);
//
// set up dQ Gemm (type 3 crr)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment