Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e1980d10
"examples/vscode:/vscode.git/clone" did not exist on "229fd8cbca989b675ed9ad30676b323eebc24fbc"
Commit
e1980d10
authored
Nov 10, 2023
by
Qianfeng Zhang
Browse files
Update in D0 shuffled loading to support bigger KPerBlock size
parent
ac3ef99c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
40 deletions
+47
-40
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
+47
-40
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
View file @
e1980d10
...
...
@@ -361,10 +361,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
C1GridDesc_M_N
{}))
>
;
static
constexpr
auto
D0N2
=
AK1
;
static
constexpr
auto
D0N1
=
AK0
;
static
constexpr
auto
D0N0
=
Number
<
NPerBlock
/
KPerBlock
>
{};
static_assert
(
NPerBlock
%
KPerBlock
==
0
);
static
constexpr
auto
D0N2
=
AK1
;
static
constexpr
auto
D0N1
=
Number
<
32
/
AK1
.
value
>
{};
static
constexpr
auto
D0N0
=
Number
<
NPerBlock
/
32
>
{};
static
constexpr
auto
D0N0_PerShuffle
=
Number
<
KPerBlock
/
32
>
{};
static
constexpr
auto
D0_NumShuffle
=
NPerBlock
/
KPerBlock
;
static_assert
(
NPerBlock
%
KPerBlock
==
0
&&
KPerBlock
%
32
==
0
,
"KPerBlock should be multiple of 32 and divisor of NPerBlock"
);
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
...
...
@@ -408,47 +412,48 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
__host__
__device__
static
constexpr
auto
GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3
()
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
D0N1
,
Number
<
MPerBlock
>
{},
D0N2
));
make_tuple
(
I1
,
I1
,
D0N0_PerShuffle
,
D0N1
,
Number
<
MPerBlock
>
{},
D0N2
));
}
__host__
__device__
static
constexpr
auto
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2
()
__host__
__device__
static
constexpr
auto
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2
_N3
()
{
constexpr
auto
d0_raw_n0_m_n
1
=
make_
naive_tensor_descriptor_packed
(
make_tuple
(
D0N1
,
Number
<
MPerBlock
>
{},
D0N2
));
constexpr
auto
d0_raw_n0_
n1_
m_n
2
=
make_naive_tensor_descriptor_packed
(
make_
tuple
(
D0N0_PerShuffle
,
D0N1
,
Number
<
MPerBlock
>
{},
D0N2
));
constexpr
auto
d0_raw_m_n
=
transform_tensor_descriptor
(
d0_raw_n0_m_n
1
,
d0_raw_n0_
n1_
m_n
2
,
make_tuple
(
make_pass_through_transform
(
Number
<
MPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
D0N1
,
D0N2
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_merge_transform
(
make_tuple
(
D0N0_PerShuffle
,
D0N1
,
D0N2
))),
make_tuple
(
Sequence
<
2
>
{},
Sequence
<
0
,
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
d0_m0_m1_n0_n1_n2
=
transform_tensor_descriptor
(
constexpr
auto
d0_m0_m1_n0_n1_n2
_n3
=
transform_tensor_descriptor
(
d0_raw_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
MPerBlock
/
MPerXdl
>
{},
Number
<
MPerXdl
>
{})),
make_unmerge_transform
(
make_tuple
((
D0N1
*
D0N2
)
/
(
I2
*
I4
),
I2
,
I4
))),
make_unmerge_transform
(
make_tuple
(
D0N0_PerShuffle
,
(
D0N1
*
D0N2
)
/
(
I2
*
I4
),
I2
,
I4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
,
4
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
,
4
,
5
>
{}));
return
d0_m0_m1_n0_n1_n2
;
return
d0_m0_m1_n0_n1_n2
_n3
;
}
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
I4
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
D0N0_PerShuffle
,
I4
,
I1
,
I4
));
static
constexpr
auto
d0_block_dst_desc_m0_n0_n1_n2_m1_n3
=
GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3
();
static
constexpr
auto
d0_block_src_desc_m0_m1_n0_n1_n2
=
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2
();
static
constexpr
auto
d0_block_src_desc_m0_m1_n0_n1_n2
_n3
=
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2
_N3
();
using
D0BlockwiseCopyGlobalToLds
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
I1
,
I1
,
I1
,
D0N1
,
MPerBlock
,
D0N2
>
,
Sequence
<
I1
,
I1
,
D0N0_PerShuffle
,
D0N1
,
MPerBlock
,
D0N2
>
,
typename
sequence_merge
<
Sequence
<
1
,
1
,
1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
>::
type
,
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
...
...
@@ -468,16 +473,16 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
using
D0ThreadwiseCopyLdsToVgpr
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_src_desc_m0_m1_n0_n1_n2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcScalarPerVector
2
>
;
using
D0ThreadwiseCopyLdsToVgpr
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_src_desc_m0_m1_n0_n1_n2
_n3
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
D0N0_PerShuffle
.
value
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// DimAccessOrder
5
,
// SrcVectorDim
4
,
// SrcScalarPerVector
2
>
;
};
struct
SharedMemTrait
...
...
@@ -907,7 +912,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Operator
::
D0ThreadwiseCopyLdsToVgpr
(
make_tuple
(
wave_id
[
I0
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
make_tuple
(
wave_id
[
I0
],
wave_m_n_id
[
I1
],
0
,
0
,
wave_m_n_id
[
I0
],
0
));
index_t
gemm1_k_block_outer_index
=
0
;
do
...
...
@@ -1016,29 +1021,31 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
D0
N0
,
1
>
{}([
&
](
auto
nr
)
{
static_for
<
0
,
D0
_NumShuffle
,
1
>
{}([
&
](
auto
nr
)
{
// load data to lds
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_n1_n2_m1_n3
,
d0_grid_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_n1_n2_m1_n3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_n1_n2_m1_n3
,
make_multi_index
(
0
,
0
,
D0N0_PerShuffle
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
D0Operator
::
d0_block_dst_desc_m0_n0_n1_n2_m1_n3
,
d0_block_buf
);
block_sync_lds
();
// read data form lds
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Operator
::
d0_block_src_desc_m0_m1_n0_n1_n2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Operator
::
d0_block_src_desc_m0_m1_n0_n1_n2_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
I0
,
nr
,
i
));
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
I0
,
nr
*
D0N0_PerShuffle
,
i
));
acc_thread_buf
(
Number
<
c_offset
>
{})
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment