Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
896408a5
Commit
896408a5
authored
Aug 14, 2023
by
letaoqin
Browse files
fix group gemm
parent
95d76f67
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
20 deletions
+38
-20
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+36
-18
No files found.
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
896408a5
...
@@ -287,8 +287,8 @@ int run(int argc, char* argv[])
...
@@ -287,8 +287,8 @@ int run(int argc, char* argv[])
p_c
,
p_c
,
p_z
,
p_z
,
p_lse
,
p_lse
,
nullptr
,
// p_acc0_biases
{}
,
// p_acc0_biases
nullptr
,
// p_acc1_biases
{}
,
// p_acc1_biases
problem_descs
,
problem_descs
,
a_element_op
,
a_element_op
,
b0_element_op
,
b0_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
896408a5
...
@@ -24,6 +24,7 @@ namespace tensor_operation {
...
@@ -24,6 +24,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
D0DataType
,
typename
GemmAccDataType
,
typename
GemmAccDataType
,
typename
GroupKernelArg
,
typename
GroupKernelArg
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -99,8 +100,17 @@ __global__ void
...
@@ -99,8 +100,17 @@ __global__ void
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
}
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
{
{
...
@@ -109,9 +119,7 @@ __global__ void
...
@@ -109,9 +119,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_d0_grid_
==
nullptr
tmp_p_d0_grid
,
?
nullptr
:
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
...
@@ -150,9 +158,7 @@ __global__ void
...
@@ -150,9 +158,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_d0_grid_
==
nullptr
tmp_p_d0_grid
,
?
nullptr
:
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
...
@@ -716,9 +722,23 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -716,9 +722,23 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b0_gs_ns_ks_lengths
,
problem_desc
.
b0_gs_ns_ks_strides
);
problem_desc
.
b0_gs_ns_ks_lengths
,
problem_desc
.
b0_gs_ns_ks_strides
);
const
D0GridDesc_M_N
d0_grid_desc_m_n
{
DeviceOp
::
MakeD0GridDescriptor_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
,
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
problem_desc
.
acc0_biases_gs_ms_ns_strides
)};
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_biases_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_biases_gs_ms_ns_strides
;
}
else
{
tmp_d0_gs_ms_ns_lengths
=
{
1
,
1
,
1
,
1
};
tmp_d0_gs_ms_ns_strides
=
{
0
,
0
,
0
,
0
};
}
const
D0GridDesc_M_N
d0_grid_desc_m_n
{
DeviceOp
::
MakeD0GridDescriptor_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
)};
const
auto
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
const
auto
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
d0_grid_desc_m_n
);
d0_grid_desc_m_n
);
...
@@ -735,9 +755,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -735,9 +755,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
const
auto
b_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
problem_desc
.
b0_gs_ns_ks_lengths
,
problem_desc
.
b0_gs_ns_ks_strides
);
problem_desc
.
b0_gs_ns_ks_lengths
,
problem_desc
.
b0_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
problem_desc
.
acc0_biases_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1_gs_os_ns_lengths
,
problem_desc
.
b1_gs_os_ns_strides
);
problem_desc
.
b1_gs_os_ns_lengths
,
problem_desc
.
b1_gs_os_ns_strides
);
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
...
@@ -812,10 +831,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -812,10 +831,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// for check
// for check
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride
;
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride
;
d0_n_length_stride
.
push_back
(
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
d0_n_length_stride
.
push_back
(
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
group_device_args_
.
push_back
(
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
...
@@ -900,6 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -900,6 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
,
auto
is_lse_storing_
)
{
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
,
auto
is_lse_storing_
)
{
const
auto
kernel
=
const
auto
kernel
=
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
D0DataType
,
GemmAccDataType
,
GemmAccDataType
,
GroupKernelArg
,
GroupKernelArg
,
AElementwiseOperation
,
AElementwiseOperation
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment