gaoqiong / composable_kernel · Commits · ac3ef99c

Commit ac3ef99c (unverified), authored Nov 08, 2023 by Dan Yao, committed via GitHub on Nov 08, 2023

Merge pull request #1010 from ROCmSoftwarePlatform/mha-train-develop-bias-shfl

Add bias with shuffle for flash attention fwd

parents acea1753 e87ddb0e

Showing 7 changed files with 218 additions and 119 deletions (+218 -119)
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_infer.cpp              +4   -4
example/52_flash_atten_bias/grouped_mutihead_attention_bias_infer.cpp                    +1   -1
example/52_flash_atten_bias/run_batched_multihead_attention_bias_infer.inc               +1   -1
example/52_flash_atten_bias/run_batched_multihead_attention_infer.inc                    +1   -1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp    +10  -14
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp    +5   -7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp         +196 -91
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_infer.cpp

@@ -101,7 +101,7 @@ using DeviceGemmInstance =
         32,           // Gemm1KPerBlock
         8,            // AK1
         8,            // BK1
-        2,            // B1K1
+        4,            // B1K1
         32,           // MPerXDL
         32,           // NPerXDL
         1,            // MXdlPerWave

@@ -121,13 +121,13 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        4,
-        S<16, 16, 1>, // B1BlockTransfer
+        8,
+        S<8, 32, 1>,  // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
         DIM / 32,
-        4,
+        2,
         false,
         1,            // CShuffleMXdlPerWavePerShuffle
         2,            // CShuffleNXdlPerWavePerShuffle
example/52_flash_atten_bias/grouped_mutihead_attention_bias_infer.cpp

@@ -120,7 +120,7 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        1,
+        8,
         S<16, 16, 1>, // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
example/52_flash_atten_bias/run_batched_multihead_attention_bias_infer.inc

@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;

     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
example/52_flash_atten_bias/run_batched_multihead_attention_infer.inc

@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;

     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp

@@ -35,7 +35,7 @@ template <typename GridwiseGemm,
          typename BGridDesc_BK0_N_BK1,
          typename B1GridDesc_BK0_N_BK1,
          typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-         typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+         typename D0GridDescriptor_M0_N0_N1_N2_M1_N3,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          typename C0MatrixMask,

@@ -60,8 +60,7 @@ __global__ void
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c1_grid_desc_mblock_mperblock_nblock_nperblock,
-           const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-               d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+           const D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const index_t h_ratio,

@@ -109,7 +108,7 @@ __global__ void
            c1de_element_op,
            a_grid_desc_ak0_m_ak1,
            b_grid_desc_bk0_n_bk1,
-           d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+           d0_grid_desc_m0_n0_m1_m2_n1_m3,
            b1_grid_desc_bk0_n_bk1,
            c1_grid_desc_mblock_mperblock_nblock_nperblock,
            block_2_ctile_map,

@@ -129,7 +128,7 @@ __global__ void
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = b1_grid_desc_bk0_n_bk1;
    ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
-   ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
+   ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = h_ratio;

@@ -206,8 +205,7 @@ template <index_t NumDimG,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          MaskingSpecialization MaskingSpec,
-         int D0sTransferSrcScalarPerVector = 4,
-         LoopScheduler LoopSched           = LoopScheduler::Default>
+         LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
    : public DeviceBatchedMultiheadAttentionInfer<NumDimG,
                                                  NumDimM,

@@ -537,9 +535,8 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
        {
            D0GridDesc_M_N d0_grid_desc_m_n_ =
                MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
-           d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
-               GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(d0_grid_desc_m_n_);
+           d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
+               GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n_);
            d0_grid_desc_g_m_n_ =
                MakeD0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);

@@ -592,8 +589,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
        typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c1_grid_desc_mblock_mperblock_nblock_nperblock_;
-       typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-           d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
+       typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;

@@ -660,7 +656,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
            DeviceOp::BGridDesc_BK0_N_BK1,
            DeviceOp::B1GridDesc_BK0_N_BK1,
            typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-           typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+           typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3,
            typename GridwiseGemm::DefaultBlock2CTileMap,
            ComputeBasePtrOfStridedBatch,
            C0MatrixMask,

@@ -685,7 +681,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
                arg.b_grid_desc_bk0_n_bk1_,
                arg.b1_grid_desc_bk0_n_bk1_,
                arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
-               arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+               arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
                arg.block_2_ctile_map_,
                arg.batch_count_,
                arg.h_ratio_,
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp

@@ -112,7 +112,7 @@ __global__ void
            c_element_op,
            arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
            arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
-           arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+           arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
            arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
            arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg_ptr[group_id].block_2_ctile_map_,

@@ -465,8 +465,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-       typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-           d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
+       typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock

@@ -576,9 +575,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
            const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
                tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
-           const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-               GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(d0_grid_desc_m_n);
+           const auto d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
+               GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n);
            const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
                problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);

@@ -628,7 +626,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
                p_c_grid,
                a_grid_desc_ak0_m_ak1,
                b_grid_desc_bk0_n_bk1,
-               d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+               d0_grid_desc_m0_n0_m1_m2_n1_m3_,
                b1_grid_desc_bk0_n_bk1,
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp

@@ -92,11 +92,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
    static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
    static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
-   static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
-                     D0BlockTransferSrcScalarPerVector == 2 ||
-                     D0BlockTransferSrcScalarPerVector == 4,
-                 "D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
-
    static_assert(LoopSched == LoopScheduler::Default,
                  "Non-default loop scheduler is currently not supported");
@@ -366,6 +361,125 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;

+   static constexpr auto D0N2 = AK1;
+   static constexpr auto D0N1 = AK0;
+   static constexpr auto D0N0 = Number<NPerBlock / KPerBlock>{};
+   static_assert(NPerBlock % KPerBlock == 0);
+
+   __host__ __device__ static constexpr auto
+   MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
+   {
+       const auto M = d0_grid_desc_m_n.GetLength(I0);
+       const auto N = d0_grid_desc_m_n.GetLength(I1);
+
+       const auto MBlock = M / MPerBlock;
+       const auto NBlock = N / NPerBlock;
+
+       const auto d0_grid_desc_m0_n0_n1_n2_m1_n3 = transform_tensor_descriptor(
+           d0_grid_desc_m_n,
+           make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                      make_unmerge_transform(make_tuple(NBlock, D0N0, D0N1, D0N2))),
+           make_tuple(Sequence<0>{}, Sequence<1>{}),
+           make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3, 5>{}));
+
+       return d0_grid_desc_m0_n0_n1_n2_m1_n3;
+   }
+
+   using D0GridDescriptor_M0_N0_N1_N2_M1_N3 =
+       remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(D0GridDesc_M_N{}))>;
+
+   struct D0Operator
+   {
+       static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
+       static_assert(ABlockTransferDstScalarPerVector_AK1 % D0BlockTransferSrcScalarPerVector == 0);
+
+       template <typename DataType>
+       struct TypeTransform
+       {
+           using Type = DataType;
+       };
+
+       template <>
+       struct TypeTransform<void>
+       {
+           using Type = ck::half_t;
+       };
+
+       __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
+       {
+           return make_naive_tensor_descriptor_packed(
+               make_tuple(I1, I1, I1, D0N1, Number<MPerBlock>{}, D0N2));
+       }
+
+       __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2()
+       {
+           constexpr auto d0_raw_n0_m_n1 =
+               make_naive_tensor_descriptor_packed(make_tuple(D0N1, Number<MPerBlock>{}, D0N2));
+
+           constexpr auto d0_raw_m_n = transform_tensor_descriptor(
+               d0_raw_n0_m_n1,
+               make_tuple(make_pass_through_transform(Number<MPerBlock>{}),
+                          make_merge_transform(make_tuple(D0N1, D0N2))),
+               make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+               make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+           constexpr auto d0_m0_m1_n0_n1_n2 = transform_tensor_descriptor(
+               d0_raw_m_n,
+               make_tuple(make_unmerge_transform(
+                              make_tuple(Number<MPerBlock / MPerXdl>{}, Number<MPerXdl>{})),
+                          make_unmerge_transform(make_tuple((D0N1 * D0N2) / (I2 * I4), I2, I4))),
+               make_tuple(Sequence<0>{}, Sequence<1>{}),
+               make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}));
+
+           return d0_m0_m1_n0_n1_n2;
+       }
+
+       static constexpr auto d0_thread_desc_ =
+           make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4));
+
+       static constexpr auto d0_block_dst_desc_m0_n0_n1_n2_m1_n3 =
+           GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3();
+       static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
+           GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
+
+       using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
+           ThisThreadBlock,
+           tensor_operation::element_wise::PassThrough,
+           tensor_operation::element_wise::PassThrough,
+           InMemoryDataOperationEnum::Set,
+           Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
+           typename sequence_merge<Sequence<1, 1, 1>,
+                                   ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
+           Sequence<0, 1, 2, 4, 3, 5>,
+           typename TypeTransform<D0DataType>::Type, // SrcData
+           typename TypeTransform<D0DataType>::Type, // DstData
+           D0GridDescriptor_M0_N0_N1_N2_M1_N3,
+           decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3),
+           Sequence<0, 1, 2, 4, 3, 5>,
+           Sequence<0, 1, 2, 4, 3, 5>,
+           5,
+           5,
+           D0BlockTransferSrcScalarPerVector,
+           ABlockTransferDstScalarPerVector_AK1,
+           1,
+           1,
+           true, // SrcResetCoord
+           true, // DstResetCoord
+           NumGemmKPrefetchStage>;
+
+       using D0ThreadwiseCopyLdsToVgpr = ThreadwiseTensorSliceTransfer_v4<
+           typename TypeTransform<D0DataType>::Type,   // SrcData
+           typename TypeTransform<D0DataType>::Type,   // DstData
+           decltype(d0_block_src_desc_m0_m1_n0_n1_n2), // SrcDesc
+           decltype(d0_thread_desc_),                  // DstDesc
+           Sequence<1, 1, 4, 1, 4>,                    // SliceLengths
+           Sequence<0, 1, 2, 3, 4>,                    // DimAccessOrder
+           4,                                          // SrcVectorDim
+           4,                                          // SrcScalarPerVector
+           2>;
+   };
+
    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
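The new MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3 above splits the bias matrix D0 from (M, N) into six dimensions: M is unmerged into (MBlock, MPerBlock) and N into (NBlock, NPerBlock/KPerBlock, AK0, AK1), so the innermost N3 = AK1 axis can be read with vector loads and the tile can be staged through LDS one KPerBlock-wide slice at a time. The sketch below is not composable_kernel code; it only spells out, with plain integer arithmetic, which 6-d coordinate a flat (m, n) element of D0 lands on under that unmerge, treating each unmerge as a row-major split and assuming KPerBlock == AK0 * AK1 (which the transform implies).

#include <cstdint>

// Illustrative helper (not part of composable_kernel): decompose a flat (m, n)
// coordinate of the D0 bias matrix into the (M0, N0, N1, N2, M1, N3) ordering
// produced by MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3.
struct D0Coord
{
    std::int64_t m0, n0, n1, n2, m1, n3;
};

constexpr D0Coord decompose_d0(std::int64_t m, std::int64_t n,
                               std::int64_t MPerBlock, std::int64_t NPerBlock,
                               std::int64_t AK0, std::int64_t AK1)
{
    const std::int64_t KPerBlock = AK0 * AK1; // assumed relation, see lead-in
    D0Coord c{};
    c.m0 = m / MPerBlock;     // M0: which M tile (MBlock index)
    c.m1 = m % MPerBlock;     // M1: row inside the tile
    c.n0 = n / NPerBlock;     // N0: which N tile (NBlock index)
    std::int64_t r = n % NPerBlock;
    c.n1 = r / KPerBlock;     // N1: KPerBlock-wide slice inside the N tile (D0N0 axis)
    r    = r % KPerBlock;
    c.n2 = r / AK1;           // N2: AK0 position inside the slice
    c.n3 = r % AK1;           // N3: AK1 position, the vectorizable innermost axis
    return c;
}

// Example with MPerBlock=128, NPerBlock=64, AK0=8, AK1=8 (so KPerBlock=64):
static_assert(decompose_d0(130, 70, 128, 64, 8, 8).m0 == 1, "second M tile");
static_assert(decompose_d0(130, 70, 128, 64, 8, 8).n1 == 0, "first 64-wide slice");
static_assert(decompose_d0(130, 70, 128, 64, 8, 8).n3 == 6, "offset inside an AK1 vector");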
@@ -403,26 +517,26 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
    };

    template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               const D0DataType* __restrict__ p_d0_grid,
                               const FloatAB* __restrict__ p_b1_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const AccElementwiseOperation& acc_element_op,
                               const B1ElementwiseOperation& b1_element_op,
                               const CElementwiseOperation& c_element_op,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                              const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
-                                  d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                              const D0GridDescriptor_M0_N0_N1_N2_M1_N3& d0_grid_desc_m0_n0_n1_n2_m1_n3,
                               const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                               const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c1_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2CTileMap& block_2_ctile_map,
                               const C0MatrixMask& c0_matrix_mask)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -683,49 +797,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
        const auto wave_id     = GetGemm0WaveIdx();
        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

-       // bias (d0 matrix)
-       constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-           make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
-                                                          I1,   // NBlockId
-                                                          m0,   // MRepeat
-                                                          n0,   // NRepeat
-                                                          m1,   // MWaveId
-                                                          n1,   // NWaveId
-                                                          m2,   // MPerXdl
-                                                          n2,   // NGroupNum
-                                                          n3,   // NInputNum
-                                                          n4)); // RegisterNum
-
-       auto d0_threadwise_copy =
-           ThreadwiseTensorSliceTransfer_v2<D0DataType,
-                                            D0DataType,
-                                            decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                                            decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                                            Sequence<I1, // MBlockId
-                                                     I1, // NBlockID
-                                                     m0, // MRepeat
-                                                     n0, // NRepeat
-                                                     m1, // MWaveId
-                                                     n1, // NWaveId
-                                                     m2, // MPerXdl
-                                                     n2, // NGroupNum
-                                                     n3, // NInputNum
-                                                     n4>,
-                                            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                                            9,
-                                            D0BlockTransferSrcScalarPerVector,
-                                            1,
-                                            false>(
-               d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-               make_multi_index(block_work_idx[I0], // MBlockId
-                                0,                  // NBlockId
-                                0,                  // mrepeat
-                                0,                  // nrepeat
-                                wave_id[I0],        // MWaveId
-                                wave_id[I1],        // NWaveId
-                                wave_m_n_id[I1],    // MPerXdl
-                                0,                  // group
-                                wave_m_n_id[I0],    // NInputIndex
-                                0));                // register number
-
        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
        // selected_mfma.k_per_blk <= Gemm1KPack
@@ -827,6 +898,17 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
            running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

        // gemm1 K loop
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
+           d0_grid_desc_m0_n0_n1_n2_m1_n3,
+           make_multi_index(block_work_idx[I0], 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});
+
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
+           make_tuple(wave_id[I0], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+
        index_t gemm1_k_block_outer_index = 0;
        do
        {
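The two copy objects constructed above define the staging path for the bias: D0BlockwiseCopyGlobalToLds moves one block-sized slice of D0 from global memory into LDS cooperatively, and D0ThreadwiseCopyLdsToVgpr then pulls each wave/lane's fragment from LDS into registers. Because the block-level destination descriptor covers a (1, 1, 1, AK0, MPerBlock, AK1) window, only one MPerBlock x KPerBlock slice of the bias is resident in LDS at a time, which is what allows the A-tile LDS allocation to be reused for it in the loop of the next hunk. A back-of-the-envelope footprint check, as illustrative arithmetic only (not composable_kernel code):

#include <cstddef>

// Illustrative only: LDS bytes occupied by one staged D0 slice, i.e. the
// (1, 1, 1, AK0, MPerBlock, AK1) window of the block-level destination
// descriptor, assuming KPerBlock == AK0 * AK1.
constexpr std::size_t d0_slice_lds_bytes(std::size_t MPerBlock,
                                         std::size_t AK0,
                                         std::size_t AK1,
                                         std::size_t bytes_per_element)
{
    return MPerBlock * AK0 * AK1 * bytes_per_element;
}

// Example: a 128 x 64 slice of half-precision bias needs 16 KiB of LDS.
static_assert(d0_slice_lds_bytes(128, 8, 8, 2) == 16 * 1024, "128x64 tile of half_t");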
@@ -920,30 +1002,53 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                // add bias
                if constexpr(!is_same<D0DataType, void>::value)
                {
-                   const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                       p_d0_grid,
-                       d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
-
-                   // get register
-                   StaticBuffer<AddressSpaceEnum::Vgpr,
-                                D0DataType,
-                                d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
-                                true>
-                       d0_thread_buf;
-
-                   // load data from global
-                   d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                          d0_grid_buf,
-                                          d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                          d0_thread_buf);
-
-                   // acc add bias
-                   static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
-                       acc_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
-                   });
-
-                   d0_threadwise_copy.MoveSrcSliceWindow(
-                       d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                       make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                   if(p_d0_grid != nullptr)
+                   {
+                       static constexpr auto& c_thread_desc = blockwise_gemm.GetCThreadDesc();
+
+                       const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                           p_d0_grid, d0_grid_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
+
+                       auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                           static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+                           D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
+
+                       auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                           D0Operator::d0_thread_desc_.GetElementSpaceSize());
+
+                       static_for<0, D0N0, 1>{}([&](auto nr) {
+                           // load data to lds
+                           d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                                                               d0_grid_buf);
+                           d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                               d0_grid_desc_m0_n0_n1_n2_m1_n3, make_multi_index(0, 0, 1, 0, 0, 0));
+                           d0_block_copy_global_to_lds.RunWrite(
+                               D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3, d0_block_buf);
+
+                           block_sync_lds();
+
+                           // read data form lds
+                           d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2,
+                                                          make_tuple(I0, I0, I0, I0, I0),
+                                                          d0_block_buf,
+                                                          D0Operator::d0_thread_desc_,
+                                                          make_tuple(I0, I0, I0, I0, I0),
+                                                          d0_thread_buf);
+
+                           // bias add
+                           static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
+                               constexpr index_t c_offset =
+                                   c_thread_desc.CalculateOffset(make_tuple(I0, nr, i));
+                               acc_thread_buf(Number<c_offset>{}) +=
+                                   ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                           });
+                       });
+
+                       d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                           d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                           make_multi_index(0, 1, -D0N0.value, 0, 0, 0));
+                   }
                }

                // softmax
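Net effect of this hunk: instead of each thread streaming its bias fragment straight from global memory through a 10-dimensional descriptor, the block now stages the bias tile through LDS in D0N0 slices and every lane picks its fragment up from there (the "shuffle" of the commit title), then adds it onto the attention scores before the softmax. The following host-side sketch only illustrates that data flow with plain arrays; the names, the row-major layout, and the single-threaded loops are assumptions, and the real kernel distributes the same work across waves and lanes using the descriptors above.

#include <cstddef>
#include <vector>

// Illustrative data-flow model of the new bias path (not composable_kernel code):
// stage one MPerBlock x KPerBlock slice of the bias into a scratch buffer standing
// in for LDS, then accumulate it into the score tile, slice by slice.
void add_bias_staged(const float* d0_global, // bias tile, row-major MPerBlock x NPerBlock
                     float* acc,             // score tile, row-major MPerBlock x NPerBlock
                     std::size_t MPerBlock,
                     std::size_t NPerBlock,
                     std::size_t KPerBlock)  // width of one staged slice
{
    std::vector<float> lds(MPerBlock * KPerBlock);        // stand-in for the reused A-tile LDS space
    const std::size_t num_slices = NPerBlock / KPerBlock; // plays the role of D0N0

    for(std::size_t nr = 0; nr < num_slices; ++nr) // static_for<0, D0N0, 1> in the kernel
    {
        // RunRead + RunWrite: copy one MPerBlock x KPerBlock slice into "LDS"
        for(std::size_t m = 0; m < MPerBlock; ++m)
            for(std::size_t k = 0; k < KPerBlock; ++k)
                lds[m * KPerBlock + k] = d0_global[m * NPerBlock + nr * KPerBlock + k];

        // d0_thread_copy_lds_to_vgpr.Run + the "bias add" static_for: each staged
        // value is read back and accumulated into the corresponding score element
        for(std::size_t m = 0; m < MPerBlock; ++m)
            for(std::size_t k = 0; k < KPerBlock; ++k)
                acc[m * NPerBlock + nr * KPerBlock + k] += lds[m * KPerBlock + k];
    }
}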