gaoqiong / composable_kernel · Commits

Commit 72faed1c
authored Oct 23, 2023 by letaoqin

    bias with shuffle

parent b23b3d71
Showing 3 changed files with 211 additions and 99 deletions.
example/52_flash_atten_bias/run_batched_multihead_attention_bias_infer.inc   +1  -1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp   +9  -12
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp   +201  -86
example/52_flash_atten_bias/run_batched_multihead_attention_bias_infer.inc

@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;

     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp

@@ -35,7 +35,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
           typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+          typename D0GridDescriptor_M0_N0_N1_N2_M1_N3,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
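The remaining hunks in this header apply the same rename everywhere the bias descriptor is threaded through: the kernel template parameters, the kernel arguments, the unused-parameter ignore list, the Argument struct, and the launch call. The substance is that the D0 (attention-bias) grid descriptor drops from the ten-dimensional M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 layout, which addressed each thread's bias elements directly in global memory, to a six-dimensional M0_N0_N1_N2_M1_N3 layout built in the gridwise header below from (MBlock, NBlock, NPerBlock/KPerBlock, AK0, MPerBlock, AK1). A minimal host-side sketch of that index split, with tile sizes that are assumed for illustration rather than taken from the commit:

    #include <array>
    #include <cstdio>

    // Assumed compile-time sizes for illustration only.
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;
    constexpr int AK0 = 4, AK1 = 8; // play the roles of D0N1 and D0N2
    constexpr int D0N0 = NPerBlock / KPerBlock;
    static_assert(AK0 * AK1 == KPerBlock && D0N0 * KPerBlock == NPerBlock);

    // Mirror of the two unmerge transforms in MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3:
    // m -> (M0, M1) and n -> (N0, N1, N2, N3).
    std::array<int, 6> split_m0_n0_n1_n2_m1_n3(int m, int n)
    {
        return {m / MPerBlock,                 // M0: block row
                n / NPerBlock,                 // N0: block column
                (n % NPerBlock) / (AK0 * AK1), // N1: KPerBlock-wide chunk inside the tile
                (n % (AK0 * AK1)) / AK1,       // N2: AK0 slice inside the chunk
                m % MPerBlock,                 // M1: row inside the tile
                n % AK1};                      // N3: AK1 vector lane
    }

    int main()
    {
        const auto c = split_m0_n0_n1_n2_m1_n3(200, 77);
        std::printf("M0=%d N0=%d N1=%d N2=%d M1=%d N3=%d\n",
                    c[0], c[1], c[2], c[3], c[4], c[5]);
    }

Running it maps (m, n) = (200, 77) to M0=1, N0=0, N1=2, N2=1, M1=72, N3=5 under these sizes.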
@@ -60,8 +60,7 @@ __global__ void
        const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
        const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c1_grid_desc_mblock_mperblock_nblock_nperblock,
-        const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+        const D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
@@ -107,7 +106,7 @@ __global__ void
            c1de_element_op,
            a_grid_desc_ak0_m_ak1,
            b_grid_desc_bk0_n_bk1,
-            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            b1_grid_desc_bk0_n_bk1,
            c1_grid_desc_mblock_mperblock_nblock_nperblock,
            block_2_ctile_map,
@@ -127,7 +126,7 @@ __global__ void
        ignore = b_grid_desc_bk0_n_bk1;
        ignore = b1_grid_desc_bk0_n_bk1;
        ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
-        ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
+        ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
        ignore = block_2_ctile_map;
        ignore = batch_count;
        ignore = compute_base_ptr_of_batch;
@@ -533,9 +532,8 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
        {
            D0GridDesc_M_N d0_grid_desc_m_n_ =
                MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
-            d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
-                GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
-                    d0_grid_desc_m_n_);
+            d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
+                GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n_);
            d0_grid_desc_g_m_n_ =
                MakeD0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
@@ -588,8 +586,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
        typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c1_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-            d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
+        typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -655,7 +652,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
                DeviceOp::BGridDesc_BK0_N_BK1,
                DeviceOp::B1GridDesc_BK0_N_BK1,
                typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
+                typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3,
                typename GridwiseGemm::DefaultBlock2CTileMap,
                ComputeBasePtrOfStridedBatch,
                C0MatrixMask,
@@ -680,7 +677,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
                arg.b_grid_desc_bk0_n_bk1_,
                arg.b1_grid_desc_bk0_n_bk1_,
                arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
+                arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
                arg.block_2_ctile_map_,
                arg.batch_count_,
                arg.compute_base_ptr_of_batch_,
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp

@@ -362,6 +362,124 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;

+    static constexpr auto D0N2 = AK1;
+    static constexpr auto D0N1 = AK0;
+    static constexpr auto D0N0 = Number<NPerBlock / KPerBlock>{};
+    static_assert(NPerBlock % KPerBlock == 0);
+
+    __host__ __device__ static constexpr auto
+    MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
+    {
+        const auto M = d0_grid_desc_m_n.GetLength(I0);
+        const auto N = d0_grid_desc_m_n.GetLength(I1);
+
+        const auto MBlock = M / MPerBlock;
+        const auto NBlock = N / NPerBlock;
+
+        const auto d0_grid_desc_m0_n0_n1_n2_m1_n3 = transform_tensor_descriptor(
+            d0_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                       make_unmerge_transform(make_tuple(NBlock, D0N0, D0N1, D0N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3, 5>{}));
+
+        return d0_grid_desc_m0_n0_n1_n2_m1_n3;
+    }
+
+    using D0GridDescriptor_M0_N0_N1_N2_M1_N3 =
+        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(D0GridDesc_M_N{}))>;
+
+    struct D0Operator
+    {
+        template <typename DataType>
+        struct TypeTransform
+        {
+            using Type                     = DataType;
+            static constexpr index_t Size0 = sizeof(DataType);
+            static constexpr index_t Size  = sizeof(DataType);
+        };
+
+        template <>
+        struct TypeTransform<void>
+        {
+            using Type                     = ck::half_t;
+            static constexpr index_t Size0 = 0;
+            static constexpr index_t Size  = sizeof(ck::half_t);
+        };
+
+        __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
+        {
+            return make_naive_tensor_descriptor_packed(
+                make_tuple(I1, I1, I1, D0N1, Number<MPerBlock>{}, D0N2));
+        }
+
+        __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2()
+        {
+            constexpr auto d0_raw_n0_m_n1 =
+                make_naive_tensor_descriptor_packed(make_tuple(D0N1, Number<MPerBlock>{}, D0N2));
+
+            constexpr auto d0_raw_m_n = transform_tensor_descriptor(
+                d0_raw_n0_m_n1,
+                make_tuple(make_pass_through_transform(Number<MPerBlock>{}),
+                           make_merge_transform(make_tuple(D0N1, D0N2))),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            constexpr auto d0_m0_m1_n0_n1_n2 = transform_tensor_descriptor(
+                d0_raw_m_n,
+                make_tuple(make_unmerge_transform(
+                               make_tuple(Number<MPerBlock / MPerXdl>{}, Number<MPerXdl>{})),
+                           make_unmerge_transform(make_tuple((D0N1 * D0N2) / (I2 * I4), I2, I4))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}));
+
+            return d0_m0_m1_n0_n1_n2;
+        }
+
+        static constexpr auto d0_thread_desc_ =
+            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4));
+
+        static constexpr auto d0_block_dst_desc_m0_n0_n1_n2_m1_n3 =
+            GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3();
+
+        static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
+            GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
+
+        using D0BlockwiseCopyGlobalToLds =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                tensor_operation::element_wise::PassThrough,
+                                                tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
+                                                Sequence<I1, I1, I1, 4, 64, 1>,
+                                                Sequence<0, 1, 2, 4, 3, 5>,
+                                                typename TypeTransform<D0DataType>::Type, // SrcData
+                                                typename TypeTransform<D0DataType>::Type, // DstData
+                                                D0GridDescriptor_M0_N0_N1_N2_M1_N3,
+                                                decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3),
+                                                Sequence<0, 1, 2, 4, 3, 5>,
+                                                Sequence<0, 1, 2, 4, 3, 5>,
+                                                5,
+                                                5,
+                                                ABlockTransferSrcScalarPerVector,
+                                                ABlockTransferDstScalarPerVector_AK1,
+                                                1,
+                                                1,
+                                                true, // SrcResetCoord
+                                                true, // DstResetCoord
+                                                NumGemmKPrefetchStage>;
+
+        using D0ThreadwiseCopyLdsToVgpr =
+            ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type,   // SrcData
+                                             typename TypeTransform<D0DataType>::Type,   // DstData
+                                             decltype(d0_block_src_desc_m0_m1_n0_n1_n2), // SrcDesc
+                                             decltype(d0_thread_desc_),                  // DstDesc
+                                             Sequence<1, 1, 4, 1, 4>, // SliceLengths
+                                             Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
+                                             4,                       // SrcVectorDim
+                                             4,                       // SrcScalarPerVector
+                                             2>;
+    };
+
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
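Two things in this new block carry the design. First, the bias tile is now described at block granularity, one (D0N1 x MPerBlock x D0N2) slab per KPerBlock-wide column chunk, so it can be staged through LDS with the same ThreadGroupTensorSliceTransfer_v4r1 machinery the kernel already uses for the A matrix. Second, TypeTransform keeps every copy type instantiable when the kernel is compiled without a bias tensor (D0DataType = void) by substituting ck::half_t as a dummy element type. A standalone sketch of that void fallback, with names and the half_t stand-in of my own choosing:

    #include <cstdio>

    struct half_t { unsigned short bits; }; // stand-in for ck::half_t

    // Sketch of D0Operator::TypeTransform: use the real element type when a bias
    // is present, and a concrete dummy type (with zero staged bytes) when not.
    template <typename DataType>
    struct BiasTypeOrHalf
    {
        using Type                 = DataType;
        static constexpr int Size0 = sizeof(DataType); // bytes per staged element
    };

    template <>
    struct BiasTypeOrHalf<void>
    {
        using Type                 = half_t; // dummy type so dead branches still type-check
        static constexpr int Size0 = 0;      // no bias tensor: nothing is staged
    };

    int main()
    {
        std::printf("float bias: %d bytes/element, no bias: %d bytes\n",
                    BiasTypeOrHalf<float>::Size0, BiasTypeOrHalf<void>::Size0);
    }

The in-class explicit specialization in the commit leans on compiler leniency; at namespace scope, as in this sketch, the pattern is plain standard C++.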
@@ -399,27 +517,28 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
    };

    template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               const D0DataType* __restrict__ p_d0_grid,
                               const FloatAB* __restrict__ p_b1_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const AccElementwiseOperation& acc_element_op,
                               const B1ElementwiseOperation& b1_element_op,
                               const CElementwiseOperation& c_element_op,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                              const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
-                                  d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                              const D0GridDescriptor_M0_N0_N1_N2_M1_N3&
+                                  d0_grid_desc_m0_n0_n1_n2_m1_n3,
                               const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                               const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c1_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2CTileMap& block_2_ctile_map,
                               const C0MatrixMask& c0_matrix_mask)
    {
+        ignore = d0_grid_desc_m0_n0_n1_n2_m1_n3;
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -679,49 +798,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
            const auto wave_id     = GetGemm0WaveIdx();
            const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

-            // bias (d0 matrix)
-            constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-                make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
-                                                               I1,   // NBlockId
-                                                               m0,   // MRepeat
-                                                               n0,   // NRepeat
-                                                               m1,   // MWaveId
-                                                               n1,   // NWaveId
-                                                               m2,   // MPerXdl
-                                                               n2,   // NGroupNum
-                                                               n3,   // NInputNum
-                                                               n4)); // RegisterNum
-
-            auto d0_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
-                D0DataType,
-                D0DataType,
-                decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                Sequence<I1, // MBlockId
-                         I1, // NBlockID
-                         m0, // MRepeat
-                         n0, // NRepeat
-                         m1, // MWaveId
-                         n1, // NWaveId
-                         m2, // MPerXdl
-                         n2, // NGroupNum
-                         n3, // NInputNum
-                         n4>,
-                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                9,
-                D0BlockTransferSrcScalarPerVector,
-                1,
-                false>(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                       make_multi_index(block_work_idx[I0], // MBlockId
-                                        0,                  // NBlockId
-                                        0,                  // mrepeat
-                                        0,                  // nrepeat
-                                        wave_id[I0],        // MWaveId
-                                        wave_id[I1],        // NWaveId
-                                        wave_m_n_id[I1],    // MPerXdl
-                                        0,                  // group
-                                        wave_m_n_id[I0],    // NInputIndex
-                                        0));                // register number
-
            // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
            // selected_mfma.k_per_blk <= Gemm1KPack
@@ -823,6 +899,18 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
            running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

            // gemm1 K loop
+            ignore = wave_m_n_id;
+            auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
+                d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                make_multi_index(block_work_idx[I0], 0, 0, 0, 0, 0),
+                tensor_operation::element_wise::PassThrough{},
+                D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3,
+                make_multi_index(0, 0, 0, 0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
+
+            auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
+                make_tuple(wave_id[I0], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+
            index_t gemm1_k_block_outer_index = 0;
            do
            {
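Both transfer objects are built once, before the outer gemm1 K loop. The blockwise copy's global window starts at this block's row tile (block_work_idx[I0]), and the LDS-to-VGPR reader's origin (wave_id[I0], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0) points each lane at the 1x1x4x1x4 slice it owns. Per chunk the global window then advances one step along N1, and after the D0N0 chunks of a tile it advances N0 while rewinding N1 (the -D0N0.value in the next hunk). The window arithmetic in plain integers, as a sketch:

    #include <cstdio>

    int main()
    {
        constexpr int D0N0 = 4; // chunks (NPerBlock / KPerBlock) per block tile
        int n0 = 0, n1 = 0;     // source-window coordinates along N0, N1

        for(int tile = 0; tile < 2; ++tile)
        {
            for(int nr = 0; nr < D0N0; ++nr)
                n1 += 1;        // MoveSrcSliceWindow(..., make_multi_index(0, 0, 1, 0, 0, 0))
            n0 += 1;            // next tile ...
            n1 -= D0N0;         // ... rewind: make_multi_index(0, 1, -D0N0.value, 0, 0, 0)
            std::printf("after tile %d: n0=%d n1=%d\n", tile, n0, n1);
        }
    }

After each tile the window lands at (n0, n1) = (tile + 1, 0), i.e. aligned to the start of the next block column.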
@@ -916,30 +1004,57 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                // add bias
                if constexpr(!is_same<D0DataType, void>::value)
                {
-                    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                        p_d0_grid,
-                        d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
-                    if(p_d0_grid != nullptr)
-                    {
-                        // get register
-                        StaticBuffer<AddressSpaceEnum::Vgpr,
-                                     D0DataType,
-                                     d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
-                                         .GetElementSpaceSize(),
-                                     true>
-                            d0_thread_buf;
-
-                        // load data from global
-                        d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                               d0_grid_buf,
-                                               d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                               make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                               d0_thread_buf);
-
-                        // acc add bias
-                        static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
-                            acc_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
-                        });
-
-                        d0_threadwise_copy.MoveSrcSliceWindow(
-                            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                            make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
-                    }
+                    ignore = d0_thread_copy_lds_to_vgpr;
+                    static constexpr auto& c_thread_desc = blockwise_gemm.GetCThreadDesc();
+                    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                        p_d0_grid, d0_grid_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
+                    auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+                        D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
+                    auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                        D0Operator::d0_thread_desc_.GetElementSpaceSize());
+                    ignore = c_thread_desc;
+                    ignore = d0_grid_buf;
+                    ignore = d0_block_buf;
+                    ignore = d0_thread_buf;
+
+                    static_for<0, D0N0, 1>{}([&](auto nr) {
+                        // load data to lds
+                        d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                                                            d0_grid_buf);
+                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                            d0_grid_desc_m0_n0_n1_n2_m1_n3, make_multi_index(0, 0, 1, 0, 0, 0));
+                        d0_block_copy_global_to_lds.RunWrite(
+                            D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3, d0_block_buf);
+                        block_sync_lds();
+
+                        // read data from lds
+                        d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2,
+                                                       make_tuple(I0, I0, I0, I0, I0),
+                                                       d0_block_buf,
+                                                       D0Operator::d0_thread_desc_,
+                                                       make_tuple(I0, I0, I0, I0, I0),
+                                                       d0_thread_buf);
+
+                        // bias add
+                        static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
+                            constexpr index_t c_offset =
+                                c_thread_desc.CalculateOffset(make_tuple(I0, nr, i));
+                            ignore = c_offset;
+                            acc_thread_buf(Number<c_offset>{}) +=
+                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                        });
+                    });
+
+                    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                        d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                        make_multi_index(0, 1, -D0N0.value, 0, 0, 0));
                }

                // softmax
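Net effect of the rewrite: instead of each thread gathering its bias elements straight from global memory through a ten-dimensional descriptor, the block now makes D0N0 passes per tile, each pass staging a (D0N1 x MPerBlock x D0N2) slab of D0 through LDS before threads pull their registers' worth out and add it to the attention logits at the offsets the C-thread descriptor assigns to that chunk. A serial model of that data movement, my own simplification with toy sizes rather than CK code:

    #include <cstdio>
    #include <vector>

    int main()
    {
        constexpr int MPerBlock = 4, KPerBlock = 4, NPerBlock = 8; // toy sizes
        constexpr int D0N0 = NPerBlock / KPerBlock;                // chunks per tile

        std::vector<float> bias(MPerBlock * NPerBlock, 1.0f); // d0 tile in "global"
        std::vector<float> lds(MPerBlock * KPerBlock);        // staged slab
        std::vector<float> acc(MPerBlock * NPerBlock, 0.0f);  // acc_thread_buf

        for(int nr = 0; nr < D0N0; ++nr) // static_for<0, D0N0, 1>
        {
            // RunRead + RunWrite: global -> LDS for one KPerBlock-wide chunk
            for(int m = 0; m < MPerBlock; ++m)
                for(int k = 0; k < KPerBlock; ++k)
                    lds[m * KPerBlock + k] = bias[m * NPerBlock + nr * KPerBlock + k];

            // block_sync_lds() would separate the two phases on a GPU

            // LDS -> registers -> accumulate, standing in for the per-lane copies
            for(int m = 0; m < MPerBlock; ++m)
                for(int k = 0; k < KPerBlock; ++k)
                    acc[m * NPerBlock + nr * KPerBlock + k] += lds[m * KPerBlock + k];
        }

        std::printf("acc[0] = %.1f\n", acc[0]);
    }

The staging buffer reuses the A block's LDS offset (SharedMemTrait::a_block_space_offset in the hunk above), presumably because the A tile is dead by the time the bias is added.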