Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
53a74710
Commit
53a74710
authored
Aug 10, 2023
by
letaoqin
Browse files
add bias to batched gemm
parent
b8f08e67
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
267 additions
and
39 deletions
+267
-39
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
...tten_bias/batched_multihead_attention_bias_forward_v2.cpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+121
-24
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
.../device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
+0
-8
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+142
-3
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
View file @
53a74710
...
@@ -19,7 +19,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
...
@@ -19,7 +19,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2
r2
.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
...
@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false;
...
@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
#if(DIM <= 32)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -153,7 +153,7 @@ using DeviceGemmInstance =
...
@@ -153,7 +153,7 @@ using DeviceGemmInstance =
Deterministic
>
;
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -226,7 +226,7 @@ using DeviceGemmInstance =
...
@@ -226,7 +226,7 @@ using DeviceGemmInstance =
Deterministic
>
;
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
53a74710
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
53a74710
...
@@ -346,14 +346,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -346,14 +346,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BSpec
,
BSpec
,
B1Spec
,
B1Spec
,
CSpec
>
;
CSpec
>
;
using
RawTransform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpecialization
::
Default
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
53a74710
...
@@ -25,6 +25,7 @@ namespace ck {
...
@@ -25,6 +25,7 @@ namespace ck {
*
*
*/
*/
template
<
typename
FloatAB
,
template
<
typename
FloatAB
,
typename
D0sDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
FloatGemm
,
typename
FloatGemm
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
...
@@ -39,6 +40,7 @@ template <typename FloatAB,
...
@@ -39,6 +40,7 @@ template <typename FloatAB,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
D0sGridDesc_M_N
,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
typename
ZGridDesc_M_N
,
typename
ZGridDesc_M_N
,
...
@@ -74,6 +76,7 @@ template <typename FloatAB,
...
@@ -74,6 +76,7 @@ template <typename FloatAB,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
BBlockLdsExtraN
,
index_t
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
@@ -86,6 +89,7 @@ template <typename FloatAB,
...
@@ -86,6 +89,7 @@ template <typename FloatAB,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
D1BlockTransferSrcScalarPerVector
,
LoopScheduler
LoopSched
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
...
@@ -93,6 +97,11 @@ template <typename FloatAB,
...
@@ -93,6 +97,11 @@ template <typename FloatAB,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
struct
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
{
static_assert
(
D0BlockTransferSrcScalarPerVector
==
1
||
D0BlockTransferSrcScalarPerVector
==
2
||
D0BlockTransferSrcScalarPerVector
==
4
,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4"
);
static
constexpr
index_t
NumD0Tensor
=
D0sDataType
::
Size
();
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -407,6 +416,52 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -407,6 +416,52 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_grid_desc_m_n
);
c_grid_desc_m_n
);
}
}
static
constexpr
auto
MakeD0sGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
return
static_cast
<
const
D0DataType
*>
(
nullptr
);
},
Number
<
NumD0Tensor
>
{});
}
// D0 desc for source in blockwise copy
template
<
typename
D0GridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
{
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
// D0s desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
D0sGridDesc_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ds_grid_desc_m_n
[
i
]);
},
Number
<
NumD0Tensor
>
{});
}
using
D0sGridPointer
=
decltype
(
MakeD0sGridPointer
());
using
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
D0sGridDesc_M_N
{}))
>
;
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
...
@@ -465,6 +520,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -465,6 +520,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename
C0MatrixMask
>
typename
C0MatrixMask
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
D0sGridPointer
p_d0s_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
ZDataType
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
...
@@ -477,6 +533,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -477,6 +533,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -891,6 +949,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -891,6 +949,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// gemm1 K loop
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
index_t
gemm1_k_block_outer_index
=
0
;
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
// bias (d matrix)
constexpr
auto
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockId
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
));
// RegisterNum
auto
d0s_threadwise_copy
=
generate_tuple
(
[
&
](
auto
i
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
return
ThreadwiseTensorSliceTransfer_v2
<
D0DataType
,
D0DataType
,
decltype
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
]),
decltype
(
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
D0BlockTransferSrcScalarPerVector
,
1
,
false
>
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
block_work_idx_m
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
));
// register number
},
Number
<
NumD0Tensor
>
{});
const
auto
d0s_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0s_grid
[
i
],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
].
GetElementSpaceSize
());
},
Number
<
NumD0Tensor
>
{});
// z is random number matrix for dropout verify
// z is random number matrix for dropout verify
//
//
// z vgpr copy to global
// z vgpr copy to global
...
@@ -983,9 +1100,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -983,9 +1100,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static_cast
<
ushort
*>
(
p_shared
),
static_cast
<
ushort
*>
(
p_shared
),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_tmp_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_tmp_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ushort
,
ushort
,
...
@@ -1163,6 +1277,31 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1163,6 +1277,31 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// add bias
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
// get register
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
d0_thread_buf
;
// load data from global
d0s_threadwise_copy
(
i
).
Run
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
d0s_grid_buf
[
i
],
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
// acc add bias
static_for
<
0
,
m0
*
n0
*
n2
*
n4
,
1
>
{}(
[
&
](
auto
j
)
{
acc_thread_buf
(
j
)
+=
d0_thread_buf
[
j
];
});
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
});
// softmax
// softmax
SoftmaxBuf
&
max
=
blockwise_softmax
.
max_value_buf
;
SoftmaxBuf
&
max
=
blockwise_softmax
.
max_value_buf
;
SoftmaxBuf
&
sum
=
blockwise_softmax
.
sum_value_buf
;
SoftmaxBuf
&
sum
=
blockwise_softmax
.
sum_value_buf
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment