Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
81639679
Commit
81639679
authored
Nov 11, 2023
by
Qianfeng Zhang
Browse files
Let D0 shuffled laoding not depend on ABlockTransfer Spec
parent
9a423017
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
16 deletions
+20
-16
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
+20
-16
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
View file @
81639679
...
@@ -361,14 +361,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -361,14 +361,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using
DefaultBlock2CTileMap
=
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
C1GridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
C1GridDesc_M_N
{}))
>
;
static
constexpr
auto
D0N2
=
AK1
;
using
D0StoreType
=
static
constexpr
auto
D0N1
=
Number
<
32
/
AK1
.
value
>
{};
typename
conditional
<
is_same
<
D0DataType
,
void
>::
value
,
half_t
,
D0DataType
>::
type
;
static
constexpr
auto
D0N0
=
Number
<
NPerBlock
/
32
>
{};
static
constexpr
auto
D0N0_PerShuffle
=
Number
<
KPerBlock
/
32
>
{};
static
constexpr
auto
D0ShuffleBlock_N
=
static
constexpr
auto
D0_NumShuffle
=
NPerBlock
/
KPerBlock
;
ck
::
math
::
min
(
static_cast
<
index_t
>
(
32768
/
sizeof
(
D0StoreType
))
/
MPerBlock
,
NPerBlock
);
static
constexpr
auto
D0N2
=
Number
<
4
*
sizeof
(
float
)
/
sizeof
(
D0StoreType
)
>
{};
static_assert
(
NPerBlock
%
KPerBlock
==
0
&&
KPerBlock
%
32
==
0
,
static
constexpr
auto
D0N1
=
Number
<
32
/
D0N2
.
value
>
{};
"KPerBlock should be multiple of 32 and divisor of NPerBlock"
);
static
constexpr
auto
D0N0
=
Number
<
NPerBlock
/
32
>
{};
// ToDo: strange issue when D0N0_PerShuffle == 4 is used (too many vgpr consumption ?)
static
constexpr
auto
D0N0_PerShuffle
=
Number
<
ck
::
math
::
min
(
D0ShuffleBlock_N
/
32
,
2
)
>
{};
static
constexpr
auto
D0_NumShuffle
=
D0N0
.
value
/
D0N0_PerShuffle
.
value
;
static
constexpr
auto
I16
=
Number
<
16
>
{};
static_assert
(
NPerBlock
%
D0ShuffleBlock_N
==
0
&&
D0ShuffleBlock_N
%
32
==
0
,
"Calculated D0ShuffleBlock_N should be multiple of 32 and divisor of NPerBlock"
);
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
...
@@ -394,9 +401,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -394,9 +401,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct
D0Operator
struct
D0Operator
{
{
static_assert
(
ABlockTransferThreadClusterLengths_AK0_M_AK1
::
Size
()
==
3
);
static_assert
(
D0N2
%
D0BlockTransferSrcScalarPerVector
==
0
);
static_assert
(
ABlockTransferDstScalarPerVector_AK1
%
D0BlockTransferSrcScalarPerVector
==
0
);
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
TypeTransform
struct
TypeTransform
...
@@ -454,8 +459,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -454,8 +459,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
I1
,
I1
,
D0N0_PerShuffle
,
D0N1
,
MPerBlock
,
D0N2
>
,
Sequence
<
I1
,
I1
,
D0N0_PerShuffle
,
D0N1
,
MPerBlock
,
D0N2
>
,
typename
sequence_merge
<
Sequence
<
1
,
1
,
1
>
,
typename
sequence_merge
<
Sequence
<
1
,
1
,
1
>
,
Sequence
<
4
,
BlockSize
/
4
,
1
>>::
type
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
>::
type
,
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
...
@@ -466,7 +470,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -466,7 +470,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
5
,
5
,
5
,
5
,
D0BlockTransferSrcScalarPerVector
,
D0BlockTransferSrcScalarPerVector
,
A
BlockTransferDstScalarPerVector
_AK1
,
D0N2
,
// D0
BlockTransferDstScalarPerVector
1
,
1
,
1
,
1
,
true
,
// SrcResetCoord
true
,
// SrcResetCoord
...
@@ -482,7 +486,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -482,7 +486,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// DimAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// DimAccessOrder
5
,
// SrcVectorDim
5
,
// SrcVectorDim
4
,
// SrcScalarPerVector
4
,
// SrcScalarPerVector
2
>
;
2
>
;
// SrcScalarStrideInVector (not used)
};
};
struct
SharedMemTrait
struct
SharedMemTrait
...
@@ -1045,7 +1049,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -1045,7 +1049,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// bias add
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
I0
,
nr
*
D0N0_PerShuffle
,
i
));
make_tuple
(
I0
,
nr
*
D0N0_PerShuffle
+
i
/
I16
,
i
%
I16
));
acc_thread_buf
(
Number
<
c_offset
>
{})
+=
acc_thread_buf
(
Number
<
c_offset
>
{})
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment