Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d173a2cb
Commit
d173a2cb
authored
Aug 08, 2023
by
letaoqin
Browse files
batched gemm add do vector load
parent
9679ba63
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
1 deletion
+34
-1
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
...tten_bias/batched_multihead_attention_bias_forward_v2.cpp
+6
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
.../device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
+28
-1
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
View file @
d173a2cb
...
...
@@ -136,6 +136,7 @@ using DeviceGemmInstance =
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
...
@@ -147,6 +148,7 @@ using DeviceGemmInstance =
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
...
...
@@ -207,6 +209,7 @@ using DeviceGemmInstance =
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
...
@@ -218,6 +221,7 @@ using DeviceGemmInstance =
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 128)
...
...
@@ -278,6 +282,7 @@ using DeviceGemmInstance =
8
,
8
,
true
,
4
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
...
@@ -289,6 +294,7 @@ using DeviceGemmInstance =
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
d173a2cb
...
...
@@ -274,6 +274,7 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
Acc0BiasTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
...
@@ -285,6 +286,7 @@ template <index_t NumDimG,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
Acc1BiasTransferSrcScalarPerVector
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
...
...
@@ -347,6 +349,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BSpec
,
B1Spec
,
CSpec
>
;
using
RawTransform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpecialization
::
Default
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
...
...
@@ -552,6 +562,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
Acc0BiasTransferSrcScalarPerVector
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
...
...
@@ -564,6 +575,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
Acc1BiasTransferSrcScalarPerVector
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
!=
MaskingSpecialization
::
MaskDisabled
,
...
...
@@ -670,7 +682,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
c_grid_desc_g_m_n_
,
d_grid_desc_g_m_n_
,
z_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())},
raw_d0_n_
(
0
)
{
// TODO ANT: implement bias addition
ignore
=
p_acc1_biases
;
...
...
@@ -709,6 +722,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
is_lse_storing_
=
false
;
}
if
constexpr
(
NumD0Tensor
)
{
const
auto
d0_grid_desc_m_n
=
RawTransform
::
MakeCGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
[
0
],
acc0_biases_gs_ms_ns_strides
[
0
]);
raw_d0_n_
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
}
}
void
Print
()
const
...
...
@@ -794,6 +813,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
index_t
m_raw_padded_
;
index_t
n_raw_padded_
;
// raw data
int
raw_d0_n_
;
};
// Invoker
...
...
@@ -1000,6 +1022,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
return
false
;
}
if
(
arg
.
raw_d0_n_
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment