Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4c8b47c0
Unverified
Commit
4c8b47c0
authored
Aug 29, 2023
by
Dan Yao
Committed by
GitHub
Aug 29, 2023
Browse files
Merge pull request #870 from ROCmSoftwarePlatform/mha-train-bias-bwd-type2
Add bias to flash attention bwd
parents
226355e7
882b3328
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
3031 additions
and
635 deletions
+3031
-635
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
..._softmax_gemm/batched_multihead_attention_backward_v2.cpp
+19
-19
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+17
-17
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v2.cpp
+17
-17
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+17
-17
example/52_flash_atten_bias/CMakeLists.txt
example/52_flash_atten_bias/CMakeLists.txt
+4
-1
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...ten_bias/batched_multihead_attention_bias_backward_v2.cpp
+843
-0
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
...ten_bias/grouped_multihead_attention_bias_backward_v2.cpp
+869
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+1
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+251
-174
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+249
-174
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+152
-70
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+152
-70
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+214
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+221
-44
include/ck/utility/static_buffer.hpp
include/ck/utility/static_buffer.hpp
+5
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
View file @
4c8b47c0
...
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM
128
// DIM should be a multiple of 8.
#define DIM
64
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -70,8 +70,8 @@ using AccDataType = F32;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
void
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
...
...
@@ -104,20 +104,20 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// clang-format on
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
...
@@ -129,20 +129,20 @@ using DeviceGemmInstance =
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
4c8b47c0
...
...
@@ -79,8 +79,8 @@ using AccDataType = F32;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
void
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
...
...
@@ -120,11 +120,11 @@ using DeviceGemmInstanceFWD =
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// clang-format on
#elif(DIM <= 64)
// clang-format off
...
...
@@ -136,11 +136,11 @@ using DeviceGemmInstanceFWD =
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
...
@@ -159,10 +159,10 @@ using DeviceGemmInstanceFWD =
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
DstScalar|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector_K1|
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
...
...
@@ -172,7 +172,7 @@ using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
View file @
4c8b47c0
...
...
@@ -69,8 +69,8 @@ using AccDataType = F32;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
void
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
...
...
@@ -103,20 +103,20 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// clang-format on
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
...
@@ -128,10 +128,10 @@ using DeviceGemmInstance =
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
...
...
@@ -141,7 +141,7 @@ using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
4c8b47c0
...
...
@@ -78,8 +78,8 @@ using AccDataType = F32;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
void
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
...
...
@@ -119,11 +119,11 @@ using DeviceGemmInstanceFWD =
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// clang-format on
#elif(DIM <= 64)
// clang-format off
...
...
@@ -135,11 +135,11 @@ using DeviceGemmInstanceFWD =
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
...
@@ -158,10 +158,10 @@ using DeviceGemmInstanceFWD =
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds|
D0BlockTransfer|
B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
SrcScalar|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerVector|
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | |
|
| | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
...
...
@@ -171,7 +171,7 @@ using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
...
...
example/52_flash_atten_bias/CMakeLists.txt
View file @
4c8b47c0
add_example_executable
(
example_batched_multihead_attention_bias_forward_v2 batched_multihead_attention_bias_forward_v2.cpp
)
add_example_executable
(
example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp
)
add_example_executable
(
example_batched_multihead_attention_bias_backward_v2 batched_multihead_attention_bias_backward_v2.cpp
)
add_example_executable
(
example_grouped_multihead_attention_bias_backward_v2 grouped_multihead_attention_bias_backward_v2.cpp
)
\ No newline at end of file
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
0 → 100644
View file @
4c8b47c0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
Computation graph:
K^T V
| |
| |
Q --- * ----- Softmax ----- * --> Y
S P
Kernel inputs:
Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
dQ, dK, dV
*/
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
INT32
=
int32_t
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
InputDataType
=
F16
;
using
OutputDataType
=
F16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
F16
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTriangleFromBottomRight
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
#endif
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1.
// If 32 < DIM <= 64 , ues prototype1.
// If 64 < DIM <= 128, ues prototype2.
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// clang-format on
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
InputDataType
,
AccDataType
>
;
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
InputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
// Ref Gemm for backward pass
// fp16 in, fp16 out
using
ReferenceGemm0GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
InputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
using
ReferenceGemm1GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
OutputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref dropout
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ZDataType
,
InputDataType
,
InputDataType
>
;
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorD
,
typename
TensorV
,
typename
TensorS
,
typename
TensorP
,
typename
TensorZ
,
typename
TensorY
,
typename
TensorLSE
=
TensorP
>
void
run_attention_fwd_host
(
const
TensorQ
&
q_g_m_k
,
const
TensorK
&
k_g_n_k
,
const
TensorD
&
d_g_m_n
,
const
TensorV
&
v_g_n_o
,
const
float
alpha
,
TensorS
&
s_g_m_n
,
TensorP
&
p_g_m_n
,
TensorY
&
y_g_m_o
,
TensorLSE
&
lse_g_m
,
TensorP
&
p_drop_g_m_n
,
TensorZ
&
z_g_m_n
,
ZDataType
p_dropout_in_16bits
,
float
rp_dropout
)
{
// S = alpha * Q * K^T
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
q_g_m_k
,
k_g_k_n
,
s_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
});
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// bias
s_g_m_n
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
ck
::
type_convert
<
AccDataType
>
(
d_g_m_n
(
idx
));
});
// masking
auto
M
=
s_g_m_n
.
GetLengths
()[
1
];
auto
N
=
s_g_m_n
.
GetLengths
()[
2
];
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
// P = Softmax(S)
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
s_g_m_n
,
p_g_m_n
,
1
,
0
,
{
2
},
&
lse_g_m
);
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// P_dropped
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
p_g_m_n
,
p_drop_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// Y = P_dropout * V
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
p_drop_g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
// method 1 will have slightly higher error; TODO: to investigate
bool
time_kernel
=
true
;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1
=
6
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.0
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
13
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
p_drop
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
printf
(
"arg11 to 12: input / output permute
\n
"
);
exit
(
0
);
}
float
p_dropout
=
1
-
p_drop
;
ZDataType
p_dropout_in_16bits
=
ZDataType
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"time_kernel: "
<<
time_kernel
<<
std
::
endl
;
std
::
cout
<<
"M: "
<<
M
<<
std
::
endl
;
std
::
cout
<<
"N: "
<<
N
<<
std
::
endl
;
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
std
::
cout
<<
"p_drop: "
<<
p_drop
<<
std
::
endl
;
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// D layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// D layout [G0, G1, M, N]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
Tensor
<
InputDataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
InputDataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Acc0BiasDataType
>
d_gs_ms_ns
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
InputDataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_ns: "
<<
d_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"z_gs_ms_ns: "
<<
z_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
Acc0BiasDataType
>
{
-
2
,
2
});
// d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break
;
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
Acc0BiasDataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
break
;
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
break
;
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break
;
default:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0,g1, m, o]
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
InputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
Acc0BiasDataType
)
*
d_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
InputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_gs_ms_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
// get z matrix
{
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Acc0BiasDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_bias;
nullptr
,
// p_acc1_bias;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
d_gs_ms_ns_lengths
,
// acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_bias_gs_ms_ns_strides
{},
// acc1_bias_gs_ms_os_lengths,
{},
// acc1_bias_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
}
// not need output z matrix
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Acc0BiasDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_bias;
nullptr
,
// p_acc1_bias;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
d_gs_ms_ns_lengths
,
// acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_bias_gs_ms_ns_strides
{},
// acc1_bias_gs_ms_os_lengths,
{},
// acc1_bias_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
qgrad_device_buf
.
SetZero
();
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
// 5 GEMM ops in total:
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// 3x MNK + 2x MNO
std
::
size_t
flop
=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std
::
size_t
num_btype
=
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
sizeof
(
InputDataType
)
*
N
*
O
+
sizeof
(
InputDataType
)
*
M
*
O
*
size_t
(
2
)
+
sizeof
(
OutputDataType
)
*
M
*
K
+
sizeof
(
OutputDataType
)
*
K
*
N
+
sizeof
(
OutputDataType
)
*
N
*
O
+
sizeof
(
Acc0BiasDataType
)
*
M
*
N
)
*
BatchCount
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool
pass
=
true
;
if
(
do_verification
)
{
// copy z matirx data form device
Tensor
<
InputDataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
InputDataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Acc0BiasDataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
z_device_buf
.
FromDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
d_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// run fwd again for y, cause z_g_m_n update
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
d_g_m_n
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
,
p_drop_g_m_n
,
z_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
// call kernel again
qgrad_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
InputDataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
InputDataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
#if PRINT_HOST
{
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"v_g_n_o ref:
\n
"
<<
v_g_n_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
}
#endif
// Gradients
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
using
RefGemm0GradArg
=
ReferenceGemm0GradInstance
::
Argument
;
auto
ref_gemm1_grad
=
ReferenceGemm1GradInstance
{};
auto
ref_gemm1_grad_invoker
=
ref_gemm1_grad
.
MakeInvoker
();
using
RefGemm1GradArg
=
ReferenceGemm1GradInstance
::
Argument
;
// dP_dropout = dY * V^T
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
ref_gemm0_grad_invoker
.
Run
(
RefGemm0GradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"pgrad_drop_g_m_n ref:
\n
"
<<
pgrad_drop_g_m_n
;
}
#endif
// dP = dP_dropout x Z
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
float
ygrad_dot_y
=
0
;
for
(
int
o
=
0
;
o
<
O
;
o
++
)
{
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
self
(
idx_gmn
)
=
ck
::
type_convert
<
InputDataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
#if PRINT_HOST
{
std
::
cout
<<
"===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
\n
"
;
std
::
cout
<<
"p_g_m_n ref:
\n
"
<<
p_g_m_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
std
::
cout
<<
"y_g_m_o ref:
\n
"
<<
y_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
}
#endif
// dV = P_drop^T * dY
auto
p_drop_g_n_m
=
p_drop_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"p_drop_g_n_m ref:
\n
"
<<
p_drop_g_n_m
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
}
#endif
// dQ = alpha * dS * K
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"qgrad_g_m_k ref:
\n
"
<<
qgrad_g_m_k
;
}
#endif
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
{
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
std
::
cout
<<
"sgrad_g_n_m ref:
\n
"
<<
sgrad_g_n_m
;
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"kgrad_g_n_k ref:
\n
"
<<
kgrad_g_n_k
;
}
#endif
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
vgrad_device_buf
.
FromDevice
(
vgrad_gs_os_ns_device_result
.
mData
.
data
());
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
std
::
cout
<<
"Checking qgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
qgrad_gs_ms_ks_device_result
.
mData
,
qgrad_gs_ms_ks_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
std
::
cout
<<
"Checking kgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
kgrad_gs_ns_ks_device_result
.
mData
,
kgrad_gs_ns_ks_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
std
::
cout
<<
"Checking vgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
vgrad_gs_os_ns_device_result
.
mData
,
vgrad_gs_os_ns_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
}
return
pass
?
((
void
)(
std
::
cout
<<
"pass
\n
"
),
0
)
:
((
void
)(
std
::
cout
<<
"fail
\n
"
),
1
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
0 → 100644
View file @
4c8b47c0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
Computation graph:
K^T V
| |
| |
Q --- * ----- Softmax ----- * --> Y
S P
Kernel inputs:
Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
dQ, dK, dV
*/
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
INT32
=
int32_t
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
InputDataType
=
F16
;
using
OutputDataType
=
F16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
F16
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
#endif
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1.
// If 32 < DIM <= 64 , ues prototype1.
// If 64 < DIM <= 128, ues prototype2.
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// clang-format on
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstance
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
128
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
InputDataType
,
AccDataType
>
;
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
InputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
// Ref Gemm for backward pass
// fp16 in, fp16 out
using
ReferenceGemm0GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
InputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
using
ReferenceGemm1GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
OutputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref dropout
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ZDataType
,
InputDataType
,
InputDataType
>
;
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorD
,
typename
TensorV
,
typename
TensorS
,
typename
TensorP
,
typename
TensorZ
,
typename
TensorY
,
typename
TensorLSE
=
TensorP
>
void
run_attention_fwd_host
(
const
TensorQ
&
q_g_m_k
,
const
TensorK
&
k_g_n_k
,
const
TensorD
&
d_g_m_n
,
const
TensorV
&
v_g_n_o
,
const
float
alpha
,
TensorS
&
s_g_m_n
,
TensorP
&
p_g_m_n
,
TensorY
&
y_g_m_o
,
TensorLSE
&
lse_g_m
,
TensorP
&
p_drop_g_m_n
,
TensorZ
&
z_g_m_n
,
ZDataType
p_dropout_in_16bits
,
float
rp_dropout
)
{
// S = alpha * Q * K^T
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
q_g_m_k
,
k_g_k_n
,
s_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
});
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// bias
s_g_m_n
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
ck
::
type_convert
<
AccDataType
>
(
d_g_m_n
(
idx
));
});
// masking
auto
M
=
s_g_m_n
.
GetLengths
()[
1
];
auto
N
=
s_g_m_n
.
GetLengths
()[
2
];
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
// P = Softmax(S)
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
s_g_m_n
,
p_g_m_n
,
1
,
0
,
{
2
},
&
lse_g_m
);
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// P_dropped
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
p_g_m_n
,
p_drop_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// Y = P_dropout * V
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
p_drop_g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
2
;
// method 1 will have slightly higher error; TODO: to investigate
bool
time_kernel
=
true
;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
p_drop
=
0.0
;
bool
input_permute
=
true
;
bool
output_permute
=
true
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
printf
(
"arg11 to 12: input / output permute
\n
"
);
exit
(
0
);
}
float
p_dropout
=
1
-
p_drop
;
ZDataType
p_dropout_in_16bits
=
ZDataType
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
std
::
vector
<
DeviceGemmInstance
::
ProblemDesc
>
problem_descs
;
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
const
void
*>
p_q
;
std
::
vector
<
const
void
*>
p_k
;
std
::
vector
<
const
void
*>
p_d0
;
std
::
vector
<
void
*>
p_z
;
// for result verification
std
::
vector
<
void
*>
p_z_nullptr
;
// for time test
std
::
vector
<
const
void
*>
p_v
;
std
::
vector
<
const
void
*>
p_y
;
std
::
vector
<
const
void
*>
p_lse
;
std
::
vector
<
void
*>
p_qgrad
;
std
::
vector
<
void
*>
p_kgrad
;
std
::
vector
<
void
*>
p_vgrad
;
std
::
vector
<
const
void
*>
p_ygrad
;
std
::
vector
<
Tensor
<
InputDataType
>>
q_g_m_ks
;
std
::
vector
<
Tensor
<
InputDataType
>>
k_g_n_ks
;
std
::
vector
<
Tensor
<
Acc0BiasDataType
>>
d0_g_m_ns
;
std
::
vector
<
Tensor
<
ZDataType
>>
z_g_m_ns
;
std
::
vector
<
Tensor
<
InputDataType
>>
v_g_n_os
;
std
::
vector
<
Tensor
<
AccDataType
>>
s_g_m_ns
;
std
::
vector
<
Tensor
<
InputDataType
>>
p_g_m_ns
;
std
::
vector
<
Tensor
<
InputDataType
>>
y_g_m_os
;
std
::
vector
<
Tensor
<
LSEDataType
>>
lse_g_ms
;
std
::
vector
<
Tensor
<
InputDataType
>>
p_drop_g_m_ns
;
std
::
vector
<
Tensor
<
InputDataType
>>
q_tensors
;
std
::
vector
<
Tensor
<
InputDataType
>>
k_tensors
;
std
::
vector
<
Tensor
<
Acc0BiasDataType
>>
d0_tensors
;
std
::
vector
<
Tensor
<
InputDataType
>>
v_tensors
;
std
::
vector
<
Tensor
<
InputDataType
>>
y_tensors
;
std
::
vector
<
Tensor
<
ZDataType
>>
z_tensors
;
std
::
vector
<
Tensor
<
LSEDataType
>>
lse_tensors
;
std
::
vector
<
Tensor
<
OutputDataType
>>
qgrad_tensors
;
std
::
vector
<
Tensor
<
OutputDataType
>>
kgrad_tensors
;
std
::
vector
<
Tensor
<
OutputDataType
>>
vgrad_tensors
;
std
::
vector
<
Tensor
<
InputDataType
>>
ygrad_tensors
;
std
::
vector
<
DeviceMemPtr
>
q_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
k_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
d0_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
z_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
v_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
y_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
lse_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
qgrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
ygrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
kgrad_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
vgrad_tensors_device
;
std
::
size_t
group_count
=
10
;
std
::
size_t
flop
=
0
,
num_byte
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1
=
rand
()
%
4
+
1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// d0 layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// d0 layout [G0, G1, M, N]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
problem_descs
.
push_back
({
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_strides
,
d0_gs_ms_ns_lengths
,
d0_gs_ms_ns_strides
,
{},
// acc1_bias_gs_ms_os_lengths,
{},
// acc1_bias_gs_ms_os_strides,
});
int
BatchCount
=
G0
*
G1
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte
+=
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
sizeof
(
InputDataType
)
*
N
*
O
+
sizeof
(
InputDataType
)
*
M
*
O
*
size_t
(
2
)
+
sizeof
(
OutputDataType
)
*
M
*
K
+
sizeof
(
OutputDataType
)
*
K
*
N
+
sizeof
(
OutputDataType
)
*
N
*
O
+
sizeof
(
Acc0BiasDataType
)
*
M
*
N
)
*
BatchCount
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
Tensor
<
InputDataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
InputDataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Acc0BiasDataType
>
d0_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
InputDataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
if
(
i
<
4
)
{
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_gs_ms_ns: "
<<
d0_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"z_gs_ms_ns: "
<<
z_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
}
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
Acc0BiasDataType
>
{
-
2
,
2
});
break
;
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
Acc0BiasDataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
break
;
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
break
;
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break
;
default:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
Tensor
<
InputDataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
InputDataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Acc0BiasDataType
>
d0_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
d0_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
k_g_n_ks
.
push_back
(
k_g_n_k
);
d0_g_m_ns
.
push_back
(
d0_g_m_n
);
z_g_m_ns
.
push_back
(
z_g_m_n
);
v_g_n_os
.
push_back
(
v_g_n_o
);
s_g_m_ns
.
push_back
(
s_g_m_n
);
p_g_m_ns
.
push_back
(
p_g_m_n
);
y_g_m_os
.
push_back
(
y_g_m_o
);
lse_g_ms
.
push_back
(
lse_g_m
);
p_drop_g_m_ns
.
push_back
(
p_drop_g_m_n
);
q_tensors
.
push_back
(
q_gs_ms_ks
);
k_tensors
.
push_back
(
k_gs_ns_ks
);
d0_tensors
.
push_back
(
d0_gs_ms_ns
);
v_tensors
.
push_back
(
v_gs_os_ns
);
y_tensors
.
push_back
(
y_gs_ms_os
);
z_tensors
.
push_back
(
z_gs_ms_ns
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
d0_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Acc0BiasDataType
)
*
d0_gs_ms_ns
.
GetElementSpaceSize
()));
z_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
GetElementSpaceSize
()));
v_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
y_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
lse_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
k_tensors_device
.
back
()
->
ToDevice
(
k_gs_ns_ks
.
data
());
d0_tensors_device
.
back
()
->
ToDevice
(
d0_gs_ms_ns
.
data
());
z_tensors_device
.
back
()
->
ToDevice
(
z_gs_ms_ns
.
data
());
v_tensors_device
.
back
()
->
ToDevice
(
v_gs_os_ns
.
data
());
ygrad_tensors_device
.
back
()
->
ToDevice
(
ygrad_gs_ms_os
.
data
());
p_q
.
push_back
(
q_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_k
.
push_back
(
k_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_d0
.
push_back
(
d0_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_z
.
push_back
(
z_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_z_nullptr
.
push_back
(
nullptr
);
p_v
.
push_back
(
v_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_y
.
push_back
(
y_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_lse
.
push_back
(
lse_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_kgrad
.
push_back
(
kgrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_vgrad
.
push_back
(
vgrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_ygrad
.
push_back
(
ygrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_qgrad
.
push_back
(
qgrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
}
auto
argument
=
gemm
.
MakeArgument
(
p_q
,
p_k
,
p_z_nullptr
,
p_v
,
p_y
,
p_lse
,
p_ygrad
,
p_qgrad
,
p_kgrad
,
p_vgrad
,
p_d0
,
{},
problem_descs
,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
// get z matrix
argument
=
gemm
.
MakeArgument
(
p_q
,
p_k
,
p_z
,
p_v
,
p_y
,
p_lse
,
p_ygrad
,
p_qgrad
,
p_kgrad
,
p_vgrad
,
p_d0
,
{},
problem_descs
,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G1
=
v_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_ks
[
i
],
k_g_n_ks
[
i
],
d0_g_m_ns
[
i
],
v_g_n_os
[
i
],
alpha
,
s_g_m_ns
[
i
],
p_g_m_ns
[
i
],
y_g_m_os
[
i
],
lse_g_ms
[
i
],
p_drop_g_m_ns
[
i
],
z_g_m_ns
[
i
],
p_dropout_in_16bits
,
rp_dropout
);
y_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
y_tensors_device
[
i
]
->
ToDevice
(
y_tensors
[
i
].
data
());
lse_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
lse_tensors_device
[
i
]
->
ToDevice
(
lse_tensors
[
i
].
data
());
qgrad_tensors_device
[
i
]
->
SetZero
();
kgrad_tensors_device
[
i
]
->
SetZero
();
vgrad_tensors_device
[
i
]
->
SetZero
();
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G0
=
v_tensors
[
i
].
GetLengths
()[
0
];
int
G1
=
v_tensors
[
i
].
GetLengths
()[
1
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
BatchCount
=
G0
*
G1
;
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
InputDataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
using
RefGemm0GradArg
=
ReferenceGemm0GradInstance
::
Argument
;
auto
ref_gemm1_grad
=
ReferenceGemm1GradInstance
{};
auto
ref_gemm1_grad_invoker
=
ref_gemm1_grad
.
MakeInvoker
();
using
RefGemm1GradArg
=
ReferenceGemm1GradInstance
::
Argument
;
// dP = dY * V^T
auto
v_g_o_n
=
v_g_n_os
[
i
].
Transpose
({
0
,
2
,
1
});
ref_gemm0_grad_invoker
.
Run
(
RefGemm0GradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
// dP = dP_dropout x Z
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_ns
[
i
],
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
float
ygrad_dot_y
=
0
;
for
(
int
o
=
0
;
o
<
O
;
o
++
)
{
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_os
[
i
](
idx_gmo
));
}
self
(
idx_gmn
)
=
ck
::
type_convert
<
InputDataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_ns
[
i
](
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
// dV = P_drop^T * dY
auto
p_drop_g_n_m
=
p_drop_g_m_ns
[
i
].
Transpose
({
0
,
2
,
1
});
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
// dQ = alpha * dS * K
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
sgrad_g_m_n
,
k_g_n_ks
[
i
],
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
sgrad_g_n_m
,
q_g_m_ks
[
i
],
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
vgrad_tensors_device
[
i
]
->
FromDevice
(
vgrad_gs_os_ns_device_result
.
data
());
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
std
::
cout
<<
"Checking qgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
qgrad_gs_ms_ks_device_result
.
mData
,
qgrad_gs_ms_ks_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
std
::
cout
<<
"Checking kgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
kgrad_gs_ns_ks_device_result
.
mData
,
kgrad_gs_ns_ks_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
std
::
cout
<<
"Checking vgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
vgrad_gs_os_ns_device_result
.
mData
,
vgrad_gs_os_ns_host_result
.
mData
,
"error"
,
1e-2
,
1e-2
);
}
}
return
pass
?
((
void
)(
std
::
cout
<<
"pass
\n
"
),
0
)
:
((
void
)(
std
::
cout
<<
"fail
\n
"
),
1
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
4c8b47c0
...
...
@@ -769,6 +769,7 @@ struct BlockwiseGemmXdlops_v2
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__host__
__device__
constexpr
auto
&
GetCThreadDesc
()
{
return
c_thread_desc_
;
}
__device__
static
auto
GetWaveIdx
()
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
4c8b47c0
...
...
@@ -27,6 +27,7 @@ namespace device {
template
<
typename
GridwiseGemm
,
typename
InputDataType
,
typename
D0DataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
LSEDataType
,
...
...
@@ -37,6 +38,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
...
...
@@ -55,6 +57,7 @@ __global__ void
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1
(
const
InputDataType
*
__restrict__
p_a_grid
,
const
InputDataType
*
__restrict__
p_b_grid
,
const
D0DataType
*
__restrict__
p_d0_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_b1_grid
,
const
InputDataType
*
__restrict__
p_c_grid
,
...
...
@@ -70,6 +73,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
...
...
@@ -115,6 +119,13 @@ __global__ void
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
)));
tmp_p_d0_grid
=
p_d0_grid
+
d0_batch_offset
;
}
if
constexpr
(
Deterministic
)
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
...
...
@@ -122,6 +133,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
...
...
@@ -138,6 +150,7 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -157,6 +170,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
...
...
@@ -173,6 +187,7 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -189,6 +204,7 @@ __global__ void
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_d0_grid
;
ignore
=
p_z_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
...
...
@@ -204,6 +220,7 @@ __global__ void
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -279,6 +296,7 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -292,11 +310,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
()
;
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
()
;
using
D0DataType
=
Acc0BiasDataType
;
using
D1DataType
=
Acc1BiasDataType
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"
Bias addition is unimplemented"
);
static_assert
(
std
::
is_void
<
D1DataType
>::
value
,
"Acc1
Bias addition is unimplemented"
);
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
;
...
...
@@ -335,31 +353,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
// V in Gemm B1 position
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
_vec
)
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
b1_gs_gemm1ns_gemm1ks_strides
_vec
),
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
),
Number
<
B1K1
>
{});
}
...
...
@@ -368,8 +386,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -395,17 +413,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -434,17 +452,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// YGrad in Gemm A position
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
_vec
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
_vec
)
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
_vec
,
y_gs_ms_os_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -470,17 +488,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
...
...
@@ -492,10 +510,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
// Z in Gemm0 C position
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
...
@@ -506,10 +524,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
_vec
)
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
_vec
,
q_gs_ms_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
}
//
...
...
@@ -517,10 +535,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
_vec
)
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
_vec
,
k_gs_ns_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
@@ -547,9 +565,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
lse_grid_desc_mraw
;
}
}
// D in Gemm0 C position
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
...
...
@@ -559,6 +584,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeDGridDescriptor_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
YGradGridDesc_O0_M_O1
=
decltype
(
MakeYGradGridDescriptor_O0_M_O1
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
...
...
@@ -582,14 +608,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
struct
ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch
()
{}
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
D0GridDesc_G_M_N
&
d0_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
d0_grid_desc_g_m_n_
(
d0_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
...
...
@@ -607,6 +636,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
b_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetD0BasePtr
(
index_t
g_idx
)
const
{
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
...
@@ -630,6 +664,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
...
...
@@ -640,6 +675,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
InputDataType
,
// TODO: distinguish A/B datatype
D0DataType
,
OutputDataType
,
ZDataType
,
GemmDataType
,
...
...
@@ -655,6 +691,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
KGridDesc_N_K
,
D0GridDesc_M_N
,
ZGridDesc_M_N
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
...
...
@@ -691,6 +728,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
D0BlockTransferSrcScalarPerVector
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -703,8 +741,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InputDataType
*
p_a_grid
,
Argument
(
const
InputDataType
*
p_a_grid
,
const
InputDataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
const
InputDataType
*
p_b1_grid
,
...
...
@@ -714,8 +751,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -727,12 +764,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
...
...
@@ -742,6 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_d0_grid_
{
p_acc0_bias
},
p_z_grid_
{
p_z_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
...
...
@@ -796,22 +834,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())},
p_drop_
{
p_drop
}
{
// TODO: implement bias addition
ignore
=
p_acc0_bias
es
;
ignore
=
p_acc1_bias
es
;
ignore
=
acc0_bias
es
_gs_ms_ns_lengths
;
ignore
=
acc0_bias
es
_gs_ms_ns_strides
;
ignore
=
acc1_bias
es
_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_bias
es
_gs_ms_gemm1ns_strides
;
ignore
=
p_acc0_bias
;
ignore
=
p_acc1_bias
;
ignore
=
acc0_bias_gs_ms_ns_lengths
;
ignore
=
acc0_bias_gs_ms_ns_strides
;
ignore
=
acc1_bias_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_bias_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -822,6 +853,28 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
GridwiseGemm
::
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
y_grid_desc_m_o_
);
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
auto
d0_grid_desc_m_n
=
MakeDGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
}
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
d0_grid_desc_g_m_n_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
...
@@ -862,6 +915,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// pointers
const
InputDataType
*
p_a_grid_
;
const
InputDataType
*
p_b_grid_
;
const
D0DataType
*
p_d0_grid_
;
ZDataType
*
p_z_grid_
;
const
InputDataType
*
p_b1_grid_
;
const
InputDataType
*
p_c_grid_
;
...
...
@@ -874,6 +928,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
...
...
@@ -884,6 +939,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// batch offsets
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
...
...
@@ -922,6 +978,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
index_t
m_raw_padded_
;
index_t
n_raw_padded_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
};
// Invoker
...
...
@@ -948,6 +1007,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1
<
GridwiseGemm
,
InputDataType
,
D0DataType
,
OutputDataType
,
ZDataType
,
LSEDataType
,
...
...
@@ -958,6 +1018,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
...
...
@@ -978,6 +1039,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_d0_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -993,6 +1055,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -1064,6 +1127,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
false
;
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
arg
.
d0_n_length_stride_
[
1
]
==
1
&&
arg
.
d0_n_length_stride_
[
0
]
%
D0BlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
if
(
arg
.
d0_n_length_stride_
[
1
]
!=
1
&&
D0BlockTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
...
...
@@ -1109,8 +1185,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
const
InputDataType
*
p_b
,
ZDataType
*
p_z
,
const
InputDataType
*
p_b1
,
...
...
@@ -1120,8 +1196,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1133,12 +1209,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
...
...
@@ -1157,8 +1233,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_qgrad_grid
,
p_kgrad_grid
,
p_vgrad_grid
,
p_acc0_bias
es
,
p_acc1_bias
es
,
p_acc0_bias
,
p_acc1_bias
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
...
...
@@ -1170,10 +1246,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
a_element_op
,
b_element_op
,
acc_element_op
,
...
...
@@ -1198,8 +1274,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1211,12 +1287,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
...
...
@@ -1225,7 +1301,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
static_cast
<
const
InputDataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
InputDataType
*>
(
p_b1
),
...
...
@@ -1235,8 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
p_acc0_bias
es
,
// cast in struct Argument
p_acc1_bias
es
,
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
)
,
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
)
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
...
...
@@ -1248,10 +1325,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
acc1_bias
es
_gs_ms_gemm1ns_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
acc1_bias_gs_ms_gemm1ns_strides
,
a_element_op
,
b_element_op
,
acc_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
4c8b47c0
...
...
@@ -27,6 +27,7 @@ namespace device {
template
<
typename
GridwiseGemm
,
typename
InputDataType
,
typename
D0DataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
LSEDataType
,
...
...
@@ -37,6 +38,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
...
...
@@ -55,6 +57,7 @@ __global__ void
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2
(
const
InputDataType
*
__restrict__
p_a_grid
,
const
InputDataType
*
__restrict__
p_b_grid
,
const
D0DataType
*
__restrict__
p_d0_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_b1_grid
,
const
InputDataType
*
__restrict__
p_c_grid
,
...
...
@@ -70,6 +73,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
...
...
@@ -115,6 +119,14 @@ __global__ void
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
)));
tmp_p_d0_grid
=
p_d0_grid
+
d0_batch_offset
;
}
if
constexpr
(
Deterministic
)
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
...
...
@@ -122,6 +134,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
...
...
@@ -138,6 +151,7 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -157,6 +171,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
...
...
@@ -173,6 +188,7 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -189,6 +205,7 @@ __global__ void
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_d0_grid
;
ignore
=
p_z_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
...
...
@@ -204,6 +221,7 @@ __global__ void
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -279,6 +297,7 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
...
@@ -299,11 +318,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
()
;
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
()
;
using
D0DataType
=
Acc0BiasDataType
;
using
D1DataType
=
Acc1BiasDataType
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"
Bias addition is unimplemented"
);
static_assert
(
std
::
is_void
<
D1DataType
>::
value
,
"Acc1
Bias addition is unimplemented"
);
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
;
...
...
@@ -342,31 +361,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
// V in Gemm B1 position
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
_vec
)
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
b1_gs_gemm1ns_gemm1ks_strides
_vec
),
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
),
Number
<
B1K1
>
{});
}
...
...
@@ -375,8 +394,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -402,17 +421,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -441,17 +460,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// YGrad in Gemm A position
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
_vec
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
_vec
)
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
_vec
,
y_gs_ms_os_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -477,17 +496,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
...
...
@@ -498,11 +517,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
v_grid_desc_n_o
,
Number
<
V_O1
>
{});
}
// D in Gemm0 C position
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
// Z in Gemm0 C position
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...
...
@@ -513,10 +539,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
_vec
)
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
_vec
,
q_gs_ms_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
}
//
...
...
@@ -524,10 +550,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
_vec
)
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
_vec
,
k_gs_ns_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
@@ -557,6 +583,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
...
...
@@ -566,6 +593,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeDGridDescriptor_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
YGradGridDesc_M0_O_M1
=
decltype
(
MakeYGradGridDescriptor_M0_O_M1
(
YGridDesc_M_O
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
...
...
@@ -589,14 +617,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
struct
ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch
()
{}
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
D0GridDesc_G_M_N
&
d0_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
d0_grid_desc_g_m_n_
(
d0_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
...
...
@@ -614,6 +645,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
b_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetD0BasePtr
(
index_t
g_idx
)
const
{
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
...
@@ -637,6 +672,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
...
...
@@ -647,6 +683,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
InputDataType
,
// TODO: distinguish A/B datatype
D0DataType
,
OutputDataType
,
ZDataType
,
GemmDataType
,
...
...
@@ -662,6 +699,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
KGridDesc_N_K
,
D0GridDesc_M_N
,
ZGridDesc_M_N
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
...
...
@@ -698,6 +736,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
D0BlockTransferSrcScalarPerVector
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
...
...
@@ -718,8 +757,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InputDataType
*
p_a_grid
,
Argument
(
const
InputDataType
*
p_a_grid
,
const
InputDataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
const
InputDataType
*
p_b1_grid
,
...
...
@@ -729,8 +767,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -742,12 +780,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
...
...
@@ -757,6 +795,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_d0_grid_
{
p_acc0_bias
},
p_z_grid_
{
p_z_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
...
...
@@ -810,22 +849,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())},
p_drop_
{
p_drop
}
{
// TODO: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_ns_lengths
;
ignore
=
acc0_biases_gs_ms_ns_strides
;
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
ignore
=
p_acc1_bias
;
ignore
=
acc1_bias_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_bias_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -837,6 +866,29 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
y_grid_desc_m_o_
);
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
auto
d0_grid_desc_m_n
=
MakeDGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_bias_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
}
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
d0_grid_desc_g_m_n_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
...
@@ -876,6 +928,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// pointers
const
InputDataType
*
p_a_grid_
;
const
InputDataType
*
p_b_grid_
;
const
D0DataType
*
p_d0_grid_
;
ZDataType
*
p_z_grid_
;
const
InputDataType
*
p_b1_grid_
;
const
InputDataType
*
p_c_grid_
;
...
...
@@ -888,6 +941,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
...
...
@@ -898,6 +952,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// batch offsets
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
...
...
@@ -936,6 +991,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t
m_raw_padded_
;
index_t
n_raw_padded_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
};
// Invoker
...
...
@@ -966,6 +1024,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2
<
GridwiseGemm
,
InputDataType
,
D0DataType
,
OutputDataType
,
ZDataType
,
LSEDataType
,
...
...
@@ -976,6 +1035,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
...
...
@@ -996,6 +1056,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_d0_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -1011,6 +1072,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -1093,6 +1155,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
false
;
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
arg
.
d0_n_length_stride_
[
1
]
==
1
&&
arg
.
d0_n_length_stride_
[
0
]
%
D0BlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
if
(
arg
.
d0_n_length_stride_
[
1
]
!=
1
&&
D0BlockTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
...
...
@@ -1143,8 +1217,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
const
InputDataType
*
p_b
,
ZDataType
*
p_z
,
const
InputDataType
*
p_b1
,
...
...
@@ -1154,8 +1228,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1167,12 +1241,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
...
...
@@ -1191,8 +1265,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_qgrad_grid
,
p_kgrad_grid
,
p_vgrad_grid
,
p_acc0_bias
es
,
p_acc1_bias
es
,
p_acc0_bias
,
p_acc1_bias
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
...
...
@@ -1204,10 +1278,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
a_element_op
,
b_element_op
,
acc_element_op
,
...
...
@@ -1232,8 +1306,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_bias
es
,
const
void
*
p_acc0_bias
,
const
void
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1245,12 +1319,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>
&
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
...
...
@@ -1259,7 +1333,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
static_cast
<
const
InputDataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
InputDataType
*>
(
p_b1
),
...
...
@@ -1269,8 +1344,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
p_acc0_bias
es
,
// cast in struct Argument
p_acc1_bias
es
,
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
)
,
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
)
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
...
...
@@ -1282,10 +1357,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
acc1_bias
es
_gs_ms_gemm1ns_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
acc1_bias_gs_ms_gemm1ns_strides
,
a_element_op
,
b_element_op
,
acc_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
4c8b47c0
...
...
@@ -27,6 +27,7 @@ namespace tensor_operation {
namespace
device
{
template
<
typename
GridwiseGemm
,
typename
D0DataType
,
typename
GroupKernelArg
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
...
...
@@ -101,6 +102,15 @@ __global__ void
(
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
);
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
}
if
constexpr
(
Deterministic
)
{
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
...
...
@@ -108,6 +118,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
...
...
@@ -124,6 +135,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -144,6 +156,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
...
...
@@ -160,6 +173,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -245,6 +259,7 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -258,11 +273,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
()
;
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
()
;
using
D0DataType
=
Acc0BiasDataType
;
using
D1DataType
=
Acc1BiasDataType
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
is_same
<
D1DataType
,
void
>::
value
,
"Bias
1
addition is unimplemented"
);
using
DeviceOp
=
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
;
struct
ProblemDesc
...
...
@@ -285,11 +300,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_bias
es
_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_bias
es
_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_bias
es
_gs_ms_os_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_bias
es
_gs_ms_os_strides
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_strides
;
};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -326,20 +341,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
//
...
...
@@ -347,8 +362,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -374,17 +389,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -395,17 +410,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
//
// dQ = alpha * dS * K
//
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
_vec
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
_vec
)
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
_vec
,
y_gs_ms_os_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -431,17 +446,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
...
...
@@ -452,10 +467,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
v_grid_desc_n_o
,
Number
<
V_O1
>
{});
}
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
@@ -482,6 +497,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
lse_grid_desc_mraw
;
}
}
// D in Gemm0 C position
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
static
auto
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
...
...
@@ -490,11 +522,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
YGradGridDesc_O0_M_O1
=
decltype
(
MakeYGradGridDescriptor_O0_M_O1
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
...
...
@@ -519,12 +553,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
D0GridDesc_G_M_N
&
d0_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
index_t
batch_stride_lse
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
d0_grid_desc_g_m_n_
(
d0_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
...
...
@@ -542,6 +578,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
b_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetD0BasePtr
(
index_t
g_idx
)
const
{
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
...
@@ -565,6 +606,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
...
...
@@ -574,6 +616,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
InputDataType
,
// TODO: distinguish A/B datatype
D0DataType
,
OutputDataType
,
ZDataType
,
GemmDataType
,
...
...
@@ -589,6 +632,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
KGridDesc_N_K
,
D0GridDesc_M_N
,
ZGridDesc_M_N
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
...
...
@@ -625,6 +669,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
D0BlockTransferSrcScalarPerVector
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -641,6 +686,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// pointers
const
InputDataType
*
p_a_grid_
;
const
InputDataType
*
p_b_grid_
;
const
D0DataType
*
p_d0_grid_
;
ZDataType
*
p_z_grid_
;
const
InputDataType
*
p_b1_grid_
;
const
InputDataType
*
p_c_grid_
;
...
...
@@ -653,6 +699,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
...
...
@@ -692,6 +739,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
index_t
batch_count_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
};
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -706,8 +756,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*
>&
p_acc0_bias
_vec
,
const
std
::
vector
<
const
void
*
>&
p_acc1_bias
_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -737,16 +787,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Qgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Kgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Vgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())))
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())
&&
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias_vec
.
size
())
||
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias_vec
.
size
()
==
0
))
&&
0
==
p_acc1_bias_vec
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! group_count_ != p_As/b/b1/c.size"
);
}
if
(
!
(
p_acc0_biases
.
size
()
==
p_acc1_biases
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! acc0_bias_vec.size != acc1_bias_vec.size"
);
}
grid_size_
=
0
;
index_t
z_random_matrix_offset
=
0
;
...
...
@@ -755,6 +803,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
InputDataType
*>
(
p_Bs
[
i
]);
const
auto
p_d0_grid
=
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias_vec
.
size
())
==
group_count_
)
?
static_cast
<
const
D0DataType
*>
(
p_acc0_bias_vec
[
i
])
:
nullptr
;
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
InputDataType
*>
(
p_B1s
[
i
]);
const
auto
p_c_grid
=
static_cast
<
const
InputDataType
*>
(
p_Cs
[
i
]);
...
...
@@ -770,6 +822,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_bias_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_bias_gs_ms_ns_strides
;
}
else
{
tmp_d0_gs_ms_ns_lengths
=
{
1
,
1
,
1
,
1
};
tmp_d0_gs_ms_ns_strides
=
{
0
,
0
,
0
,
0
};
}
const
D0GridDesc_M_N
d0_grid_desc_m_n
{
DeviceOp
::
MakeD0GridDescriptor_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
)};
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
const
auto
z_grid_desc_m_n
=
DeviceOp
::
MakeZGridDescriptor_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
...
...
@@ -789,6 +858,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
...
...
@@ -825,6 +896,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
compute_base_ptr_of_batch
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k
,
b_grid_desc_g_n_k
,
d0_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
...
...
@@ -836,18 +908,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
grid_size_
+=
grid_size_grp
;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if
(
!
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
.
size
()
==
NumAcc0Bias
&&
problem_desc
.
acc0_biases_gs_ms_ns_strides
.
size
()
==
NumAcc0Bias
&&
problem_desc
.
acc1_biases_gs_ms_os_lengths
.
size
()
==
NumAcc1Bias
&&
problem_desc
.
acc1_biases_gs_ms_os_strides
.
size
()
==
NumAcc1Bias
))
{
throw
std
::
runtime_error
(
"wrong! number of biases in function argument does not "
"match that in template argument"
);
}
const
auto
raw_m_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
]);
const
auto
raw_n_padded
=
GridwiseGemm
::
GetPaddedSize
(
...
...
@@ -855,6 +915,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_kernel_args_
.
push_back
({
p_a_grid
,
p_b_grid
,
p_d0_grid
,
p_z_grid
,
p_b1_grid
,
p_c_grid
,
...
...
@@ -865,6 +926,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_vgrad_grid
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
...
...
@@ -886,6 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
// for check
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride
;
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
...
...
@@ -900,15 +967,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_g_m_n
,
batch_count
});
batch_count
,
d0_n_length_stride
});
}
// TODO: implement bias addition
// ignore = p_acc0_bias
es
;
// ignore = p_acc1_bias
es
;
// ignore = acc0_bias
es
_gs_ms_ns_lengths;
// ignore = acc0_bias
es
_gs_ms_ns_strides;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_strides;
// ignore = p_acc0_bias
_vec
;
// ignore = p_acc1_bias
_vec
;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias_gs_ms_gemm1ns_strides;
}
// element-wise op
...
...
@@ -964,6 +1032,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1
<
GridwiseGemm
,
D0DataType
,
GroupKernelArg
,
AElementwiseOperation
,
BElementwiseOperation
,
...
...
@@ -1061,6 +1130,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
false
;
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
device_arg
.
d0_n_length_stride_
[
1
]
==
1
&&
device_arg
.
d0_n_length_stride_
[
0
]
%
D0BlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
if
(
device_arg
.
d0_n_length_stride_
[
1
]
!=
1
&&
D0BlockTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part
// of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const
auto
MzRaw
=
device_arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
];
...
...
@@ -1128,8 +1210,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*
>&
p_acc0_bias
_vec
,
const
std
::
vector
<
const
void
*
>&
p_acc1_bias
_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1149,8 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads
,
p_Kgrads
,
p_Vgrads
,
p_acc0_bias
es
,
p_acc1_bias
es
,
p_acc0_bias
_vec
,
p_acc1_bias
_vec
,
problem_desc_vec
,
a_element_op
,
b_element_op
,
...
...
@@ -1176,8 +1258,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*
>&
p_acc0_bias
_vec
,
const
std
::
vector
<
const
void
*
>&
p_acc1_bias
_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1197,8 +1279,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads
,
p_Kgrads
,
p_Vgrads
,
p_acc0_bias
es
,
// cast in struct Argument
p_acc1_bias
es
,
// cast in struct Argument
p_acc0_bias
_vec
,
// cast in struct Argument
p_acc1_bias
_vec
,
// cast in struct Argument
problem_desc_vec
,
a_element_op
,
b_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
4c8b47c0
...
...
@@ -27,6 +27,7 @@ namespace tensor_operation {
namespace
device
{
template
<
typename
GridwiseGemm
,
typename
D0DataType
,
typename
GroupKernelArg
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
...
...
@@ -100,6 +101,15 @@ __global__ void
auto
z_matrix_ptr
=
(
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
);
const
D0DataType
*
tmp_p_d0_grid
=
nullptr
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
}
if
constexpr
(
Deterministic
)
{
...
...
@@ -108,6 +118,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
...
...
@@ -124,6 +135,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -144,6 +156,7 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
tmp_p_d0_grid
,
z_matrix_ptr
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
...
...
@@ -160,6 +173,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -245,6 +259,7 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
...
@@ -265,11 +280,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
()
;
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
()
;
using
D0DataType
=
Acc0BiasDataType
;
using
D1DataType
=
Acc1BiasDataType
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
is_same
<
D1DataType
,
void
>::
value
,
"Bias
1
addition is unimplemented"
);
using
DeviceOp
=
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
;
struct
ProblemDesc
...
...
@@ -292,11 +307,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_bias
es
_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_bias
es
_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_bias
es
_gs_ms_os_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_bias
es
_gs_ms_os_strides
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_strides
;
};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -333,31 +348,31 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
*/
// Q in Gemm A position
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
_vec
,
a_gs_ms_ks_strides
_vec
),
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
),
Number
<
AK1
>
{});
}
// K in Gemm B0 position
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
_vec
,
b_gs_ns_ks_strides
_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
),
Number
<
BK1
>
{});
}
// V in Gemm B1 position
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
_vec
)
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
_vec
,
b1_gs_gemm1ns_gemm1ks_strides
_vec
),
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
),
Number
<
B1K1
>
{});
}
...
...
@@ -366,8 +381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// VGrad in Gemm C position
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
_vec
)
static
auto
MakeVGradGridDescriptor_N_O
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...
...
@@ -393,17 +408,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
_vec
(
num_dims
),
v_gs_ns_os_strides
_vec
(
num_dims
);
std
::
vector
<
index_t
>
v_gs_ns_os_lengths
(
num_dims
),
v_gs_ns_os_strides
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths
_vec
[
i
]
=
v_gs_os_ns_lengths
_vec
[
id_new
];
v_gs_ns_os_strides
_vec
[
i
]
=
v_gs_os_ns_strides
_vec
[
id_new
];
v_gs_ns_os_lengths
[
i
]
=
v_gs_os_ns_lengths
[
id_new
];
v_gs_ns_os_strides
[
i
]
=
v_gs_os_ns_strides
[
id_new
];
}
const
auto
vgrad_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths
_vec
,
v_gs_ns_os_strides
_vec
)
v_gs_ns_os_lengths
,
v_gs_ns_os_strides
)
.
second
;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
...
...
@@ -435,10 +450,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// QGrad in Gemm C position
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
_vec
)
static
auto
MakeQGradGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
q_gs_ms_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
_vec
,
q_gs_ms_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
}
//
...
...
@@ -446,16 +461,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// KGrad in Gemm C position
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
_vec
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
_vec
)
static
auto
MakeKGradGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
k_gs_ns_ks_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
_vec
,
k_gs_ns_ks_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
}
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
_vec
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
_vec
)
static
auto
MakeZGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
_vec
,
z_gs_ms_ns_strides
_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
}
static
auto
MakeLSEGridDescriptor_M
(
index_t
MRaw
)
...
...
@@ -482,6 +497,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
lse_grid_desc_mraw
;
}
}
// D in Gemm0 C position
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
static
auto
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
...
...
@@ -490,11 +522,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
YGradGridDesc_M0_O_M1
=
decltype
(
MakeYGradGridDescriptor_M0_O_M1
(
YGridDesc_M_O
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
...
...
@@ -519,12 +553,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
D0GridDesc_G_M_N
&
d0_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
d0_grid_desc_g_m_n_
(
d0_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
...
...
@@ -542,6 +578,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
b_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetD0BasePtr
(
index_t
g_idx
)
const
{
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
...
@@ -565,6 +606,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
...
...
@@ -574,6 +616,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
InputDataType
,
// TODO: distinguish A/B datatype
D0DataType
,
OutputDataType
,
ZDataType
,
GemmDataType
,
...
...
@@ -589,6 +632,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
KGridDesc_N_K
,
D0GridDesc_M_N
,
ZGridDesc_M_N
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
...
...
@@ -625,6 +669,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
D0BlockTransferSrcScalarPerVector
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
...
...
@@ -649,6 +694,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// pointers
const
InputDataType
*
p_a_grid_
;
const
InputDataType
*
p_b_grid_
;
const
D0DataType
*
p_d0_grid_
;
ZDataType
*
p_z_grid_
;
const
InputDataType
*
p_b1_grid_
;
const
InputDataType
*
p_c_grid_
;
...
...
@@ -661,6 +707,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
...
...
@@ -700,6 +747,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
index_t
batch_count_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
};
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -714,8 +764,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*
>&
p_acc0_bias
_vec
,
const
std
::
vector
<
const
void
*
>&
p_acc1_bias
_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -745,16 +795,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Qgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Kgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Vgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())))
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())
&&
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias_vec
.
size
())
||
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias_vec
.
size
()
==
0
))
&&
0
==
p_acc1_bias_vec
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! group_count_ != p_As/b/b1/c.size"
);
}
if
(
!
(
p_acc0_biases
.
size
()
==
p_acc1_biases
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! acc0_bias_vec.size != acc1_bias_vec.size"
);
}
grid_size_
=
0
;
index_t
z_random_matrix_offset
=
0
;
...
...
@@ -763,6 +811,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
InputDataType
*>
(
p_Bs
[
i
]);
const
auto
p_d0_grid
=
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias_vec
.
size
())
==
group_count_
)
?
static_cast
<
const
D0DataType
*>
(
p_acc0_bias_vec
[
i
])
:
nullptr
;
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
InputDataType
*>
(
p_B1s
[
i
]);
const
auto
p_c_grid
=
static_cast
<
const
InputDataType
*>
(
p_Cs
[
i
]);
...
...
@@ -778,6 +830,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_bias_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_bias_gs_ms_ns_strides
;
}
else
{
tmp_d0_gs_ms_ns_lengths
=
{
1
,
1
,
1
,
1
};
tmp_d0_gs_ms_ns_strides
=
{
0
,
0
,
0
,
0
};
}
const
D0GridDesc_M_N
d0_grid_desc_m_n
{
DeviceOp
::
MakeD0GridDescriptor_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
)};
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
const
auto
z_grid_desc_m_n
=
DeviceOp
::
MakeZGridDescriptor_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeB1GridDescriptor_BK0_N_BK1
(
...
...
@@ -797,6 +866,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
...
...
@@ -833,6 +904,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
compute_base_ptr_of_batch
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k
,
b_grid_desc_g_n_k
,
d0_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
...
...
@@ -844,18 +916,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
grid_size_
+=
grid_size_grp
;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if
(
!
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
.
size
()
==
NumAcc0Bias
&&
problem_desc
.
acc0_biases_gs_ms_ns_strides
.
size
()
==
NumAcc0Bias
&&
problem_desc
.
acc1_biases_gs_ms_os_lengths
.
size
()
==
NumAcc1Bias
&&
problem_desc
.
acc1_biases_gs_ms_os_strides
.
size
()
==
NumAcc1Bias
))
{
throw
std
::
runtime_error
(
"wrong! number of biases in function argument does not "
"match that in template argument"
);
}
const
auto
raw_m_padded
=
GridwiseGemm
::
GetPaddedSize
(
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
]);
const
auto
raw_n_padded
=
GridwiseGemm
::
GetPaddedSize
(
...
...
@@ -863,6 +923,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_kernel_args_
.
push_back
({
p_a_grid
,
p_b_grid
,
p_d0_grid
,
p_z_grid
,
p_b1_grid
,
p_c_grid
,
...
...
@@ -873,6 +934,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_vgrad_grid
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
...
...
@@ -894,6 +956,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
// for check
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride
;
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
...
...
@@ -908,15 +975,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_g_m_n
,
batch_count
});
batch_count
,
d0_n_length_stride
});
}
// TODO: implement bias addition
// ignore = p_acc0_bias
es
;
// ignore = p_acc1_bias
es
;
// ignore = acc0_bias
es
_gs_ms_ns_lengths;
// ignore = acc0_bias
es
_gs_ms_ns_strides;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_strides;
// ignore = p_acc0_bias
_vec
;
// ignore = p_acc1_bias
_vec
;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias_gs_ms_gemm1ns_strides;
}
// element-wise op
...
...
@@ -971,6 +1039,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2
<
GridwiseGemm
,
D0DataType
,
GroupKernelArg
,
AElementwiseOperation
,
BElementwiseOperation
,
...
...
@@ -1067,6 +1136,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
false
;
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
if
(
device_arg
.
d0_n_length_stride_
[
1
]
==
1
&&
device_arg
.
d0_n_length_stride_
[
0
]
%
D0BlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
if
(
device_arg
.
d0_n_length_stride_
[
1
]
!=
1
&&
D0BlockTransferSrcScalarPerVector
!=
1
)
{
return
false
;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part
// of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const
auto
MzRaw
=
device_arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
];
...
...
@@ -1140,8 +1222,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*
>&
p_acc0_bias
_vec
,
const
std
::
vector
<
const
void
*
>&
p_acc1_bias
_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1161,8 +1243,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads
,
p_Kgrads
,
p_Vgrads
,
p_acc0_bias
es
,
p_acc1_bias
es
,
p_acc0_bias
_vec
,
p_acc1_bias
_vec
,
problem_desc_vec
,
a_element_op
,
b_element_op
,
...
...
@@ -1188,8 +1270,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>&
p_acc0_bias
es
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*
>&
p_acc0_bias
_vec
,
const
std
::
vector
<
const
void
*
>&
p_acc1_bias
_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1209,8 +1291,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads
,
p_Kgrads
,
p_Vgrads
,
p_acc0_bias
es
,
// cast in struct Argument
p_acc1_bias
es
,
// cast in struct Argument
p_acc0_bias
_vec
,
// cast in struct Argument
p_acc1_bias
_vec
,
// cast in struct Argument
problem_desc_vec
,
a_element_op
,
b_element_op
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
4c8b47c0
...
...
@@ -21,6 +21,7 @@
namespace
ck
{
template
<
typename
InputDataType
,
typename
D0DataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
GemmDataType
,
...
...
@@ -36,6 +37,7 @@ template <typename InputDataType,
typename
QGridDesc_K0_M_K1
,
typename
KGridDesc_K0_N_K1
,
typename
KGridDesc_N_K
,
typename
D0GridDesc_M_N
,
typename
ZGridDesc_M_N
,
typename
VGridDesc_O0_N_O1
,
typename
YGridDesc_M_O
,
...
...
@@ -72,6 +74,7 @@ template <typename InputDataType,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -1219,13 +1222,128 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
math
::
max
(
p_slash_sgrad_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
}
// D0
static
constexpr
auto
D0M2
=
Number
<
4
>
{};
static
constexpr
auto
D0M1
=
Number
<
MPerXdl
>
{}
/
D0M2
;
static
constexpr
auto
D0M0
=
Number
<
MPerBlock
>
{}
/
Number
<
MPerXdl
>
{};
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
{
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
D0M0
,
D0M1
,
D0M2
)),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
3
,
5
>
{},
Sequence
<
1
,
4
>
{}));
return
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
}
using
D0GridDescriptor_M0_N0_M1_M2_N1_M3
=
remove_cvref_t
<
decltype
(
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
D0GridDesc_M_N
{}))
>
;
struct
D0Loader
{
template
<
typename
DataType
>
struct
TypeTransform
{
using
Type
=
DataType
;
};
template
<
>
struct
TypeTransform
<
void
>
{
using
Type
=
ck
::
half_t
;
};
static
constexpr
index_t
NThreadClusterLengths
=
MPerXdl
;
static_assert
(
MPerXdl
<=
KPerBlock
);
static_assert
(
D0BlockTransferSrcScalarPerVector
*
NThreadClusterLengths
<=
NPerBlock
,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock"
);
__host__
__device__
static
constexpr
auto
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
}
__host__
__device__
static
constexpr
auto
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2
()
{
constexpr
auto
d0_raw_m0_n_m1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
constexpr
auto
d0_n0_n1_m0_m1_m2
=
transform_tensor_descriptor
(
d0_raw_m0_n_m1
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
D0M1
/
I2
,
I2
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NPerBlock
/
NPerXdl
>
{},
Number
<
NPerXdl
>
{})),
make_pass_through_transform
(
D0M2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2
;
}
static
constexpr
auto
d0_block_write_desc_m0_n0_m1_m2_n1_m3
=
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3
();
static
constexpr
auto
d0_block_read_desc_n0_n1_m0_m1_m2
=
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2
();
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
using
D0BlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
I1
,
I1
,
I1
,
D0M1
,
NPerBlock
,
D0M2
>
,
// BlockSliceLengths
Sequence
<
1
,
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_write_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
5
,
// DstVectorDim
D0BlockTransferSrcScalarPerVector
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
1
,
1
,
true
,
true
,
// DstResetCoord
1
>
;
using
D0ThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
2
,
// SrcScalarPerVector
2
>
;
};
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
YGradGridDesc_O0_M_O1
>
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
const
InputDataType
*
__restrict__
p_k_grid
,
const
D0DataType
*
__restrict__
p_d0_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
...
...
@@ -1242,6 +1360,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
...
...
@@ -1792,6 +1911,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// gemm0 M loop
index_t
gemm0_m_block_outer_index
=
num_gemm0_m_block_outer_loop
-
1
;
// D0
auto
d0_block_copy_global_to_lds
=
typename
D0Loader
::
D0BlockwiseCopy
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
do
{
auto
m_block_data_idx_on_grid
=
...
...
@@ -1939,6 +2070,57 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// add bias
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
static
constexpr
auto
&
c_thread_desc
=
s_blockwise_gemm
.
GetCThreadDesc
();
const
auto
d0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
k_block_space_offset
,
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Loader
::
d0_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
// load data to lds
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
block_sync_lds
();
// read data form lds
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Loader
::
d0_block_read_desc_n0_n1_m0_m1_m2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
D0Loader
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
mr
,
I0
,
i
));
s_slash_p_thread_buf
(
Number
<
c_offset
>
{})
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
});
});
// load k
gemm_tile_k_blockwise_copy
.
RunWrite
(
GemmBlockwiseCopy
::
k_block_desc_k0_n_k1
,
k_block_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
4c8b47c0
...
...
@@ -21,6 +21,7 @@
namespace
ck
{
template
<
typename
InputDataType
,
typename
D0DataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
GemmDataType
,
...
...
@@ -36,6 +37,7 @@ template <typename InputDataType,
typename
QGridDesc_K0_M_K1
,
typename
KGridDesc_K0_N_K1
,
typename
KGridDesc_N_K
,
typename
D0GridDesc_M_N
,
typename
ZGridDesc_M_N
,
typename
VGridDesc_N0_O_N1
,
typename
YGridDesc_M_O
,
...
...
@@ -72,6 +74,7 @@ template <typename InputDataType,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
BBlockLdsExtraN
,
index_t
D0BlockTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
...
@@ -1150,13 +1153,128 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_block_bytes_end
);
}
// D0
static
constexpr
auto
D0M2
=
Number
<
4
>
{};
static
constexpr
auto
D0M1
=
Number
<
MPerXdl
>
{}
/
D0M2
;
static
constexpr
auto
D0M0
=
Number
<
MPerBlock
>
{}
/
Number
<
MPerXdl
>
{};
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
{
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
D0M0
,
D0M1
,
D0M2
)),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
3
,
5
>
{},
Sequence
<
1
,
4
>
{}));
return
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
}
using
D0GridDescriptor_M0_N0_M1_M2_N1_M3
=
remove_cvref_t
<
decltype
(
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
D0GridDesc_M_N
{}))
>
;
struct
D0Loader
{
template
<
typename
DataType
>
struct
TypeTransform
{
using
Type
=
DataType
;
};
template
<
>
struct
TypeTransform
<
void
>
{
using
Type
=
ck
::
half_t
;
};
static
constexpr
index_t
NThreadClusterLengths
=
32
;
static_assert
(
NPerXdl
==
32
);
static_assert
(
D0BlockTransferSrcScalarPerVector
*
NThreadClusterLengths
<=
NPerBlock
,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock"
);
__host__
__device__
static
constexpr
auto
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
}
__host__
__device__
static
constexpr
auto
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2
()
{
constexpr
auto
d0_raw_m0_n_m1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
D0M1
,
Number
<
NPerBlock
>
{},
D0M2
));
constexpr
auto
d0_n0_n1_m0_m1_m2
=
transform_tensor_descriptor
(
d0_raw_m0_n_m1
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
D0M1
/
I2
,
I2
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NPerBlock
/
NPerXdl
>
{},
Number
<
NPerXdl
>
{})),
make_pass_through_transform
(
D0M2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2
;
}
static
constexpr
auto
d0_block_write_desc_m0_n0_m1_m2_n1_m3
=
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3
();
static
constexpr
auto
d0_block_read_desc_n0_n1_m0_m1_m2
=
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2
();
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I4
,
I1
,
D0M2
));
using
D0BlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
I1
,
I1
,
I1
,
D0M1
,
NPerBlock
,
D0M2
>
,
// BlockSliceLengths
Sequence
<
1
,
1
,
1
,
BlockSize
/
NThreadClusterLengths
,
NThreadClusterLengths
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// ThreadClusterArrangeOrder
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_write_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
1
,
2
,
3
,
5
,
4
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
2
,
4
,
3
,
5
>
,
// DstDimAccessOrder
4
,
// SrcVectorDim
5
,
// DstVectorDim
D0BlockTransferSrcScalarPerVector
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
1
,
1
,
true
,
true
,
// DstResetCoord
1
>
;
using
D0ThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
typename
TypeTransform
<
D0DataType
>::
Type
,
// SrcData
typename
TypeTransform
<
D0DataType
>::
Type
,
// DstData
decltype
(
d0_block_read_desc_n0_n1_m0_m1_m2
),
// SrcDesc
decltype
(
d0_thread_desc_
),
// DstDesc
Sequence
<
1
,
1
,
4
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
2
,
// SrcScalarPerVector
2
>
;
};
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
YGradGridDesc_M0_O_M1
>
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
const
InputDataType
*
__restrict__
p_k_grid
,
const
D0DataType
*
__restrict__
p_d0_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
...
...
@@ -1173,6 +1291,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
...
...
@@ -1621,7 +1740,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0
,
//
wave_m_n_id
[
I1
]),
// NPerXdl
tensor_operation
::
element_wise
::
PassThrough
{}};
//
// set up Y dot dY
//
...
...
@@ -1714,6 +1832,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// gemm0 M loop
index_t
gemm0_m_block_outer_index
=
num_gemm0_m_block_outer_loop
-
1
;
// D0
auto
d0_block_copy_global_to_lds
=
typename
D0Loader
::
D0BlockwiseCopy
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
do
{
auto
m_block_data_idx_on_grid
=
...
...
@@ -1855,6 +1984,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// add bias
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
static
constexpr
auto
&
c_thread_desc
=
s_blockwise_gemm
.
GetCThreadDesc
();
const
auto
d0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Loader
::
d0_thread_desc_
.
GetElementSpaceSize
());
ignore
=
d0_thread_buf
;
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
// load data to lds
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
D0Loader
::
d0_block_write_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
block_sync_lds
();
// read data form lds
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Loader
::
d0_block_read_desc_n0_n1_m0_m1_m2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_block_buf
,
D0Loader
::
d0_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
mr
,
I0
,
i
));
s_slash_p_thread_buf
(
Number
<
c_offset
>
{})
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
});
});
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
...
...
@@ -1929,11 +2107,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to
implement given that
// the A1 source buffer is static buffer holding the output
of first GEMM and
// requires constexpr offset by design. Therefore, we pass
tensor coordinate offset
// explicitly in Run() below.
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
//
RunRead(),
RunWrite(), and MoveSliceWindow(). But it is impossible to
//
implement given that
the A1 source buffer is static buffer holding the output
//
of first GEMM and
requires constexpr offset by design. Therefore, we pass
//
tensor coordinate offset
explicitly in Run() below.
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy
.
RunRead
(
ygrad_grid_desc_m0_o_m1
,
...
...
@@ -2089,11 +2267,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dK = scalar * dS^T * Q
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to
implement given that
// the A1 source buffer is static buffer holding the output
of first GEMM and
// requires constexpr offset by design. Therefore, we pass
tensor coordinate offset
// explicitly in Run() below.
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements
//
RunRead(),
RunWrite(), and MoveSliceWindow(). But it is impossible to
//
implement given that
the A1 source buffer is static buffer holding the output
//
of first GEMM and
requires constexpr offset by design. Therefore, we pass
//
tensor coordinate offset
explicitly in Run() below.
// preload data into LDS
kgrad_gemm_tile_q_blockwise_copy
.
RunRead
(
q_grid_desc_m0_k_m1
,
q_grid_buf
);
...
...
@@ -2179,7 +2357,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
while
(
0
<
gemm0_m_block_outer_index
--
);
// end j loop
// shuffle dK&dV and write
...
...
include/ck/utility/static_buffer.hpp
View file @
4c8b47c0
...
...
@@ -111,6 +111,11 @@ struct StaticBufferTupleOfVector
return
base
::
operator
()(
i_v
).
template
AsType
<
S
>()(
i_s
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
S
&
operator
()(
Number
<
I
>
i_v
,
Number
<
I
>
i_s
)
{
return
base
::
operator
()(
i_v
).
template
AsType
<
S
>()(
i_s
);
}
// Get X
// i is offset of S, not X. i should be aligned to X
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment