Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
13129772
Commit
13129772
authored
Oct 26, 2023
by
letaoqin
Browse files
update to grouped gemm
parent
5cc0fd88
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
9 deletions
+7
-9
example/52_flash_atten_bias/grouped_mutihead_attention_bias_infer.cpp
...lash_atten_bias/grouped_mutihead_attention_bias_infer.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
+5
-7
No files found.
example/52_flash_atten_bias/grouped_mutihead_attention_bias_infer.cpp
View file @
13129772
...
...
@@ -120,7 +120,7 @@ using DeviceGemmInstance =
8
,
8
,
true
,
1
,
8
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
View file @
13129772
...
...
@@ -202,7 +202,7 @@ template <index_t NumDimG,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
:
public
DeviceBatchedMultiheadAttentionInfer
<
NumDimG
,
NumDimM
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
View file @
13129772
...
...
@@ -109,7 +109,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_
n1_
m2_n
2
_m3_
n3_n4_n5_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n
1
_m3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
...
...
@@ -461,8 +461,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_N1_N2_M1_N3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -567,9 +566,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
const
D0GridDesc_M_N
d0_grid_desc_m_n
{
DeviceOp
::
MakeD0GridDescriptor_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
)};
const
auto
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
d0_grid_desc_m_n
);
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3
(
d0_grid_desc_m_n
);
const
auto
b1_grid_desc_bk0_n_bk1
=
MakeB1GridDescriptor_BK0_N_BK1
(
problem_desc
.
b1_gs_os_ns_lengths
,
problem_desc
.
b1_gs_os_ns_strides
);
...
...
@@ -619,7 +617,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
p_c_grid
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_
n1_
m2_n
2
_m3_
n3_n4_n5
,
d0_grid_desc_m0_n0_m1_m2_n
1
_m3_
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment