Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
27d764eb
Commit
27d764eb
authored
Mar 10, 2023
by
ltqin
Browse files
Merge branch 'attn-bwd-develop' into attn-bwd-bf16-rtz
parents
022ce136
55057f09
Changes
17
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
4205 additions
and
106 deletions
+4205
-106
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+2
-2
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
...ale_softmax_gemm/batched_multihead_attention_backward.cpp
+168
-35
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp
...cale_softmax_gemm/batched_multihead_attention_forward.cpp
+144
-0
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
..._scale_softmax_gemm/batched_multihead_attention_train.cpp
+305
-34
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp
...ale_softmax_gemm/grouped_multihead_attention_backward.cpp
+1032
-0
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp
...cale_softmax_gemm/grouped_multihead_attention_forward.cpp
+144
-0
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
..._softmax_gemm/run_batched_multihead_attention_forward.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+13
-11
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+7
-8
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
...vice_batched_multihead_attention_forward_xdl_cshuffle.hpp
+1
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
+1190
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
+1181
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
...vice_grouped_multihead_attention_forward_xdl_cshuffle.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
+0
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+10
-9
No files found.
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
27d764eb
...
...
@@ -7,8 +7,8 @@ add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_pe
add_example_executable
(
example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp
)
add_example_executable
(
example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp
)
add_example_executable
(
example_
batch
ed_multihead_attention_backward
_pt1 batch
ed_multihead_attention_backward
_pt1
.cpp
)
add_example_executable
(
example_batched_multihead_attention_backward
_pt2
batched_multihead_attention_backward
_pt2
.cpp
)
add_example_executable
(
example_
group
ed_multihead_attention_backward
group
ed_multihead_attention_backward.cpp
)
add_example_executable
(
example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp
)
add_example_executable
(
example_batched_multihead_attention_train batched_multihead_attention_train.cpp
)
add_custom_target
(
example_gemm_scale_softmax_gemm
)
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward
_pt2
.cpp
→
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
View file @
27d764eb
...
...
@@ -24,8 +24,8 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK
1
#define
USING_K128 1
#define USING_MASK
0
#define
DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -36,7 +36,8 @@ Kernel outputs:
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
@@ -90,9 +91,81 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if USING_K128
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
// If 32 < DIM <= 64 , ues prototype1 2nd template.
// If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
32
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
1
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
...
...
@@ -121,7 +194,7 @@ using DeviceGemmInstance =
128
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
128
,
// Gemm1NPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
...
...
@@ -130,7 +203,7 @@ using DeviceGemmInstance =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
...
...
@@ -154,14 +227,81 @@ using DeviceGemmInstance =
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#else
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// DataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 64, // Gemm1NPerBlock
// 64, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 1, // MXdlPerWave
// 4, // NXdlPerWave
// 2, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 2,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
...
...
@@ -190,8 +330,8 @@ using DeviceGemmInstance =
128
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
64
,
// Gemm1NPerBlock
64
,
// Gemm1KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
...
...
@@ -199,7 +339,7 @@ using DeviceGemmInstance =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
...
...
@@ -219,15 +359,16 @@ using DeviceGemmInstance =
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
...
...
@@ -339,25 +480,15 @@ int run(int argc, char* argv[])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
#if USING_K128
ck
::
index_t
K
=
128
;
ck
::
index_t
O
=
128
;
#else
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
#endif
ck
::
index_t
G0
=
3
;
ck
::
index_t
G1
=
2
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
54
;
ck
::
index_t
G1
=
16
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.2
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
@@ -384,12 +515,10 @@ int run(int argc, char* argv[])
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
alpha
=
std
::
stof
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
p_drop
=
std
::
stoi
(
argv
[
13
]);
}
else
{
...
...
@@ -402,6 +531,11 @@ int run(int argc, char* argv[])
exit
(
0
);
}
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"time_kernel: "
<<
time_kernel
<<
std
::
endl
;
...
...
@@ -536,7 +670,6 @@ int run(int argc, char* argv[])
// = 0
}
// calculate y & log-sum-exp beforehand
Tensor
<
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp
View file @
27d764eb
...
...
@@ -9,6 +9,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
...
...
@@ -73,6 +75,77 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if(DIM <= 32)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
32
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
...
...
@@ -142,6 +215,77 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
View file @
27d764eb
...
...
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define
USING_HD32 0
#define
DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -43,8 +43,9 @@ Kernel outputs:
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
@@ -99,6 +100,11 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
// If 32 < DIM <= 64 , ues prototype1 2nd template.
// If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32)
using
DeviceGemmInstanceFWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
...
...
@@ -132,7 +138,7 @@ using DeviceGemmInstanceFWD =
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
...
...
@@ -141,7 +147,7 @@ using DeviceGemmInstanceFWD =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -160,23 +166,17 @@ using DeviceGemmInstanceFWD =
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// Headdim/K/O should be a multiple of 8, and it's only supported up to 64 in prototype1.
// If Headdim/K/O <= 32, ues 1st template.
// If 32 < Headdim/K/O <= 64, ues 2nd template.
#if USING_HD32
// 1st template
using
DeviceGemmInstanceBWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
PT
1
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
V
1
<
NumDimG
,
NumDimM
,
NumDimN
,
...
...
@@ -242,10 +242,79 @@ using DeviceGemmInstanceBWD =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#else
// 2nd template
#elif(DIM <= 64)
using
DeviceGemmInstanceFWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
DataType
,
DataType
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
using
DeviceGemmInstanceBWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
PT
1
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
V
1
<
NumDimG
,
NumDimM
,
NumDimN
,
...
...
@@ -311,6 +380,212 @@ using DeviceGemmInstanceBWD =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// DataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 64, // Gemm1NPerBlock
// 64, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 1, // MXdlPerWave
// 4, // NXdlPerWave
// 2, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 2,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstanceFWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
DataType
,
DataType
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
using
DeviceGemmInstanceBWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: S = alpha * Q * K^T
...
...
@@ -382,14 +657,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
#if USING_MASK
auto
N
=
s_g_m_n
.
GetLengths
()[
2
];
const
auto
mask
=
DeviceGemmInstanceFWD
::
C0MatrixMask
(
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
#endif
// P = Softmax(S)
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
...
...
@@ -424,22 +697,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
12
9
;
// 512
ck
::
index_t
N
=
12
9
;
// 512
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
M
=
5
12
;
// 512
ck
::
index_t
N
=
5
12
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G1
=
6
;
// 16
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
input_permute
=
true
;
bool
output_permute
=
true
;
float
p_drop
=
0.0
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
p_drop
=
0.2
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
@@ -466,12 +734,10 @@ int run(int argc, char* argv[])
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
alpha
=
std
::
stof
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
p_drop
=
std
::
stoi
(
argv
[
13
]);
}
else
{
...
...
@@ -484,6 +750,11 @@ int run(int argc, char* argv[])
exit
(
0
);
}
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"time_kernel: "
<<
time_kernel
<<
std
::
endl
;
...
...
example/32_batched_gemm_scale_softmax_gemm/
batch
ed_multihead_attention_backward
_pt1
.cpp
→
example/32_batched_gemm_scale_softmax_gemm/
group
ed_multihead_attention_backward.cpp
View file @
27d764eb
This diff is collapsed.
Click to expand it.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp
View file @
27d764eb
...
...
@@ -9,6 +9,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
...
...
@@ -73,6 +75,77 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if(DIM <= 32)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
32
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
...
...
@@ -142,6 +215,77 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
View file @
27d764eb
...
...
@@ -11,8 +11,8 @@ int run(int argc, char* argv[])
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck
::
index_t
M
=
1000
;
// 120
ck
::
index_t
N
=
1000
;
// 1000
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
27d764eb
...
...
@@ -10,10 +10,7 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
output_permute
=
true
;
float
p_drop
=
0.1
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
p_drop
=
0.2
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
@@ -27,14 +24,15 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
6
)
else
if
(
argc
==
7
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
input_permute
=
std
::
stoi
(
argv
[
4
]);
output_permute
=
std
::
stoi
(
argv
[
5
]);
p_drop
=
std
::
stoi
(
argv
[
4
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
}
else
{
...
...
@@ -45,6 +43,10 @@ int run(int argc, char* argv[])
exit
(
0
);
}
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1
;
// scaling after 1st gemm
std
::
size_t
group_count
=
8
;
...
...
@@ -81,10 +83,10 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
+
1
);
int
N
=
128
*
(
rand
()
%
8
+
1
);
int
K
=
64
;
int
O
=
64
;
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
3
+
1
;
int
G1
=
rand
()
%
5
+
1
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_
pt
1.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_
v
1.hpp
View file @
27d764eb
...
...
@@ -50,10 +50,9 @@ template <typename GridwiseGemm,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
1
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_
pt
1
(
kernel_batched_multihead_attention_backward_xdl_cshuffle_
v
1
(
const
DataType
*
__restrict__
p_a_grid
,
const
DataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
...
...
@@ -233,7 +232,7 @@ template <index_t NumDimG,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
PT
1
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
V
1
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
...
...
@@ -255,7 +254,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
PT
1
;
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
V
1
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -597,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
};
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_
PT
1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_
V
1
<
DataType
,
// TODO: distinguish A/B datatype
GemmDataType
,
GemmAccDataType
,
...
...
@@ -900,7 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_
pt
1
<
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_
v
1
<
GridwiseGemm
,
DataType
,
ZDataType
,
...
...
@@ -1231,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
PT
1"
str
<<
"DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_
V
1"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle
_v2
.hpp
View file @
27d764eb
...
...
@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_
v
2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_
pt
2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -231,7 +231,7 @@ template <index_t NumDimG,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
_V2
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
...
...
@@ -253,7 +253,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
;
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
_V2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -1230,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle"
str
<<
"DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
_V2
"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
27d764eb
...
...
@@ -413,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ZDataType
,
GemmDataType
,
GemmAccDataType
,
CShuffleDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
0 → 100644
View file @
27d764eb
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
0 → 100644
View file @
27d764eb
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
View file @
27d764eb
...
...
@@ -424,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ZDataType
,
GemmDataType
,
GemmAccDataType
,
CShuffleDataType
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
27d764eb
...
...
@@ -85,7 +85,7 @@ template <typename DataType,
bool
PadN
,
bool
MaskOutUpperTriangle
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_
PT
1
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_
V
1
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_
v
2.hpp
→
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_
pt
2.hpp
View file @
27d764eb
File moved
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
27d764eb
...
...
@@ -21,6 +21,7 @@
namespace
ck
{
template
<
typename
FloatAB
,
typename
ZDataType
,
typename
FloatGemm
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
...
...
@@ -274,11 +275,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
Gemm1N
=
b1_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
if
(
Gemm1N
!=
K
)
{
std
::
cout
<<
"SizeK must be equal to SizeO (equal attention head size)"
<<
'\n'
;
return
false
;
}
//
if(Gemm1N != K)
//
{
//
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
//
return false;
//
}
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
Gemm1N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
{
...
...
@@ -424,7 +425,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
FloatLSE
*
__restrict__
p_lse_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
...
...
@@ -876,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
n4
,
// DstScalarPerVector
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment