gaoqiong / composable_kernel

Commit c9c85206
authored Feb 23, 2023 by danyao12

Add template for prototype1

parent d7244818
Showing 4 changed files with 86 additions and 5 deletions
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp  (+77 -0)
  include/ck/ck.hpp  (+1 -1)
  include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp  (+2 -0)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp  (+6 -4)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp

@@ -25,6 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
+#define USING_HD32 0
 #include <iostream>
 #include <numeric>
@@ -86,6 +87,80 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+// Headdim/K/O should be a multiple of 8, and it's only supported up to 64 in prototype1.
+// If Headdim/K/O <= 32, use 1st template.
+// If 32 < Headdim/K/O <= 64, use 2nd template.
+#if USING_HD32
+// 1st template
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+        NumDimG,
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        NumDimO,
+        DataType,
+        ZDataType,
+        LSEDataType,
+        Acc0BiasDataType,
+        Acc1BiasDataType,
+        AccDataType,
+        ShuffleDataType,
+        QKVElementOp,
+        QKVElementOp,
+        Scale,
+        QKVElementOp,
+        YElementOp,
+        GemmSpec,
+        TensorSpecQ,
+        TensorSpecK,
+        TensorSpecV,
+        TensorSpecY,
+        1,
+        256,
+        128,            // MPerBlock
+        128,            // NPerBlock
+        32,             // KPerBlock
+        32,             // Gemm1NPerBlock
+        32,             // Gemm1KPerBlock
+        8,              // AK1
+        8,              // BK1
+        2,              // B1K1
+        32,             // MPerXDL
+        32,             // NPerXDL
+        1,              // MXdlPerWave
+        4,              // NXdlPerWave
+        1,              // Gemm1NXdlPerWave
+        1,              // Gemm2NXdlPerWave
+        S<4, 64, 1>,    // ABlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<4, 64, 1>,    // BBlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<8, 32, 1>,    // B1BlockTransfer
+        S<0, 2, 1>,
+        S<0, 2, 1>,
+        1,
+        4,
+        2,
+        false,
+        1,              // CShuffleMXdlPerWavePerShuffle
+        1,              // CShuffleNXdlPerWavePerShuffle
+        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
+#else
+// 2nd template
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
         NumDimG,
@@ -125,6 +200,7 @@ using DeviceGemmInstance =
         1,           // MXdlPerWave
         4,           // NXdlPerWave
         2,           // Gemm1NXdlPerWave
+        2,           // Gemm2NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -151,6 +227,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization
+#endif
 // Ref Gemm0: S = alpha * Q * K^T
 // fp16 in, fp32 out
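The comments added in this file state a head-dimension rule (multiple of 8, at most 64; 1st template up to 32, 2nd template up to 64), while the example itself switches between the two DeviceGemmInstance configurations with the USING_HD32 preprocessor flag. Below is a minimal, hypothetical sketch, not part of the commit, of how the same rule could be checked at compile time; Prototype1TemplateHd32, Prototype1TemplateHd64 and SelectPrototype1Template are illustrative stand-ins, not CK types.

    // Illustrative sketch (assumed names, not CK code): enforce the Headdim/K/O rule
    // from the added comments at compile time and pick a tag for the matching template.
    #include <cstdio>
    #include <type_traits>

    struct Prototype1TemplateHd32 {}; // stand-in for the 1st template (Headdim/K/O <= 32)
    struct Prototype1TemplateHd64 {}; // stand-in for the 2nd template (32 < Headdim/K/O <= 64)

    template <int HeadDim>
    struct SelectPrototype1Template
    {
        static_assert(HeadDim % 8 == 0, "Headdim/K/O should be a multiple of 8");
        static_assert(HeadDim <= 64, "prototype1 only supports Headdim/K/O up to 64");
        using type =
            std::conditional_t<(HeadDim <= 32), Prototype1TemplateHd32, Prototype1TemplateHd64>;
    };

    int main()
    {
        // HeadDim = 40 is a multiple of 8 and falls in (32, 64], so the 2nd template applies.
        static_assert(
            std::is_same_v<SelectPrototype1Template<40>::type, Prototype1TemplateHd64>, "");
        std::puts("HeadDim 40 -> 2nd template");
        return 0;
    }

In the example the choice is made with the USING_HD32 flag rather than a trait, so the rule is documented but not enforced.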
include/ck/ck.hpp

@@ -118,7 +118,7 @@
 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
 #endif
 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
-#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
+#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 0
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
 // experimental feature: in-regsiter sub-dword transpose
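This hunk toggles CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK from 1 to 0, i.e. the experimental offset trick for out-of-bounds checking on buffer atomic-add is switched off. A minimal sketch, assuming only that the macro is consumed as a 0/1 preprocessor switch; the function below is hypothetical, not CK code:

    // Hypothetical illustration of consuming a 0/1 configuration macro like the one above.
    #include <cstdio>

    #ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
    #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 0 // value set by this commit
    #endif

    void report_atomic_add_oob_mode()
    {
    #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
        std::puts("buffer atomic-add OOB check: offset-trick path");
    #else
        std::puts("buffer atomic-add OOB check: default path");
    #endif
    }

    int main() { report_atomic_add_oob_mode(); }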
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp

@@ -204,6 +204,7 @@ template <index_t NumDimG,
           index_t MXdlPerWave,
           index_t NXdlPerWave,
           index_t Gemm1NXdlPerWave,
+          index_t Gemm2NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,

@@ -627,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
            MXdlPerWave,
            NXdlPerWave,
            Gemm1NXdlPerWave,
+           Gemm2NXdlPerWave,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp

@@ -51,6 +51,7 @@ template <typename DataType,
           index_t MXdlPerWave,
           index_t NXdlPerWave,
           index_t Gemm1NXdlPerWave,
+          index_t Gemm2NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
@@ -726,9 +727,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static_assert(Sum_M % MPerXdl == 0, "");
-        static constexpr index_t GemmNWave   = 2;
+        static constexpr index_t GemmNWave   = Free0_N / Gemm2NXdlPerWave / MPerXdl;
         static constexpr index_t GemmOWave   = BlockSize / get_warp_size() / GemmNWave;
-        static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
+        static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMLoop   = Free1_M / Sum_M;
         static constexpr index_t GemmMPack   =
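The hunk above replaces the hard-coded GemmNWave = 2 with a value derived from the new Gemm2NXdlPerWave template parameter, and GemmNRepeat becomes Gemm2NXdlPerWave itself, so the product GemmNWave * GemmNRepeat * MPerXdl still covers Free0_N. A standalone arithmetic check of the new derivation, with assumed example values; the numbers below are illustrative, not taken from the commit:

    // Assumed example values; only the derivation mirrors the hunk above.
    constexpr int Free0_N          = 128; // assumed N extent of the Gemm2 tile
    constexpr int MPerXdl          = 32;  // assumed XDL tile size
    constexpr int BlockSize        = 256; // assumed threads per block
    constexpr int WarpSize         = 64;  // typical wavefront size on AMD GPUs
    constexpr int Gemm2NXdlPerWave = 2;   // the newly exposed tuning parameter

    constexpr int GemmNWave   = Free0_N / Gemm2NXdlPerWave / MPerXdl; // 128 / 2 / 32 = 2
    constexpr int GemmOWave   = BlockSize / WarpSize / GemmNWave;     // 256 / 64 / 2  = 2
    constexpr int GemmNRepeat = Gemm2NXdlPerWave;                     // 2

    static_assert(GemmNWave * GemmNRepeat * MPerXdl == Free0_N,
                  "waves x repeats x XDL tile must cover Free0_N");

    int main() { return 0; }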
@@ -1563,8 +1564,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                 n3, // NInputNum
                 n4>,
             Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
             9,  // DstVectorDim
             1,  // DstScalarPerVector
             InMemoryDataOperationEnum::Set,
             1,  // DstScalarStrideInVector
             true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -1937,6 +1938,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             else
             {
                 s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
             }
         });
     }