Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5d6bfabb
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "b75a491a78a205f4b78bf2fb7356074c97b93d86"
Commit
5d6bfabb
authored
Aug 08, 2023
by
letaoqin
Browse files
add d vector load template parameters
parent
8e7b98eb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
4 deletions
+16
-4
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
...tten_bias/grouped_multihead_attention_bias_forward_v2.cpp
+9
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
.../device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
+4
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
...n/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
+3
-1
No files found.
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
View file @
5d6bfabb
...
@@ -136,6 +136,7 @@ using DeviceGemmInstance =
...
@@ -136,6 +136,7 @@ using DeviceGemmInstance =
8
,
8
,
8
,
8
,
true
,
true
,
1
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
@@ -147,7 +148,8 @@ using DeviceGemmInstance =
...
@@ -147,7 +148,8 @@ using DeviceGemmInstance =
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
,
// MaskingSpecialization
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -207,6 +209,7 @@ using DeviceGemmInstance =
...
@@ -207,6 +209,7 @@ using DeviceGemmInstance =
8
,
8
,
8
,
8
,
true
,
true
,
1
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
@@ -218,7 +221,8 @@ using DeviceGemmInstance =
...
@@ -218,7 +221,8 @@ using DeviceGemmInstance =
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
,
// MaskingSpecialization
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -278,6 +282,7 @@ using DeviceGemmInstance =
...
@@ -278,6 +282,7 @@ using DeviceGemmInstance =
8
,
8
,
8
,
8
,
true
,
true
,
1
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
@@ -289,7 +294,8 @@ using DeviceGemmInstance =
...
@@ -289,7 +294,8 @@ using DeviceGemmInstance =
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
,
// MaskingSpecialization
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
Deterministic
>
;
#endif
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
5d6bfabb
...
@@ -256,6 +256,7 @@ template <index_t NumDimG,
...
@@ -256,6 +256,7 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
bool
BBlockLdsExtraN
,
index_t
Acc0BiasTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
@@ -267,6 +268,7 @@ template <index_t NumDimG,
...
@@ -267,6 +268,7 @@ template <index_t NumDimG,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
Acc1BiasTransferSrcScalarPerVector
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
...
@@ -561,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -561,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockTransferDstScalarPerVector_BK1
,
true
,
true
,
Acc0BiasTransferSrcScalarPerVector
,
BBlockLdsExtraN
,
BBlockLdsExtraN
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferThreadClusterArrangeOrder
,
...
@@ -574,6 +577,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -574,6 +577,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
Acc1BiasTransferSrcScalarPerVector
,
LoopSched
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
!=
MaskingSpecialization
::
MaskDisabled
,
MaskingSpec
!=
MaskingSpecialization
::
MaskDisabled
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
5d6bfabb
...
@@ -74,6 +74,7 @@ template <typename FloatAB,
...
@@ -74,6 +74,7 @@ template <typename FloatAB,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
BBlockLdsExtraN
,
index_t
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
@@ -86,6 +87,7 @@ template <typename FloatAB,
...
@@ -86,6 +87,7 @@ template <typename FloatAB,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
D1BlockTransferSrcScalarPerVector
,
LoopScheduler
LoopSched
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
...
@@ -930,7 +932,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -930,7 +932,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
n4
>
,
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
9
,
4
,
D0BlockTransferSrcScalarPerVector
,
1
,
1
,
false
>
(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
false
>
(
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment