Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5cc0fd88
"examples/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "dcda3b56b2ed0a01988d5827f8c7d705060c79ba"
Commit
5cc0fd88
authored
Oct 25, 2023
by
letaoqin
Browse files
D0 data loading use D0BlockTransferSrcScalarPerVector
parent
e938bd61
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
17 deletions
+9
-17
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_infer.cpp
...tten_bias/batched_gemm_multihead_attention_bias_infer.cpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
+0
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
+5
-12
No files found.
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_infer.cpp
View file @
5cc0fd88
...
@@ -101,7 +101,7 @@ using DeviceGemmInstance =
...
@@ -101,7 +101,7 @@ using DeviceGemmInstance =
32
,
// Gemm1KPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// AK1
8
,
// BK1
8
,
// BK1
2
,
// B1K1
4
,
// B1K1
32
,
// MPerXDL
32
,
// MPerXDL
32
,
// NPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
1
,
// MXdlPerWave
...
@@ -121,13 +121,13 @@ using DeviceGemmInstance =
...
@@ -121,13 +121,13 @@ using DeviceGemmInstance =
8
,
8
,
8
,
8
,
true
,
true
,
4
,
8
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
DIM
/
32
,
4
,
4
,
2
,
false
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
View file @
5cc0fd88
...
@@ -202,7 +202,6 @@ template <index_t NumDimG,
...
@@ -202,7 +202,6 @@ template <index_t NumDimG,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
int
D0sTransferSrcScalarPerVector
=
4
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
struct
DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
:
public
DeviceBatchedMultiheadAttentionInfer
<
NumDimG
,
:
public
DeviceBatchedMultiheadAttentionInfer
<
NumDimG
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
View file @
5cc0fd88
...
@@ -88,11 +88,6 @@ template <typename FloatAB,
...
@@ -88,11 +88,6 @@ template <typename FloatAB,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct
GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
{
{
static_assert
(
D0BlockTransferSrcScalarPerVector
==
1
||
D0BlockTransferSrcScalarPerVector
==
2
||
D0BlockTransferSrcScalarPerVector
==
4
,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4"
);
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -392,20 +387,18 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -392,20 +387,18 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct
D0Operator
struct
D0Operator
{
{
static_assert
(
ABlockTransferThreadClusterLengths_AK0_M_AK1
::
Size
()
==
3
);
static_assert
(
ABlockTransferThreadClusterLengths_AK0_M_AK1
::
Size
()
==
3
);
static_assert
(
ABlockTransferDstScalarPerVector_AK1
%
D0BlockTransferSrcScalarPerVector
==
0
);
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
TypeTransform
struct
TypeTransform
{
{
using
Type
=
DataType
;
using
Type
=
DataType
;
static
constexpr
index_t
Size0
=
sizeof
(
DataType
);
static
constexpr
index_t
Size
=
sizeof
(
DataType
);
};
};
template
<
>
template
<
>
struct
TypeTransform
<
void
>
struct
TypeTransform
<
void
>
{
{
using
Type
=
ck
::
half_t
;
using
Type
=
ck
::
half_t
;
static
constexpr
index_t
Size0
=
0
;
static
constexpr
index_t
Size
=
sizeof
(
ck
::
half_t
);
};
};
__host__
__device__
static
constexpr
auto
GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3
()
__host__
__device__
static
constexpr
auto
GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3
()
...
@@ -463,7 +456,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
...
@@ -463,7 +456,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
5
,
5
,
5
,
5
,
A
BlockTransferSrcScalarPerVector
,
D0
BlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
1
,
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment