Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7aa37568
Commit
7aa37568
authored
Aug 11, 2023
by
danyao12
Browse files
qloop dropout optimize
parent
4274096b
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
435 additions
and
380 deletions
+435
-380
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
...e_softmax_gemm/batched_multihead_attention_forward_v2.cpp
+6
-3
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+15
-15
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
...e_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
+6
-3
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+15
-15
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+7
-9
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+15
-12
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+9
-11
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...u/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...u/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+34
-32
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+211
-169
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
View file @
7aa37568
...
...
@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
...
...
@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8
,
true
,
4
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
7aa37568
...
...
@@ -113,11 +113,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|
Dropout|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl|
Step|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|
|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -129,11 +129,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|
Dropout|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl|
Step|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|
|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -152,11 +152,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|
Dropout|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl|
Step|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|
|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
View file @
7aa37568
...
...
@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
1
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
...
...
@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm1NXdlPerWave
1
,
// DropoutStep
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8
,
true
,
1
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
16
,
16
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
7aa37568
...
...
@@ -112,11 +112,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|
Dropout|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl|
Step|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|
|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -128,11 +128,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|
Dropout|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl|
Step|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|
|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -151,11 +151,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|
Dropout|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl|
Step|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|
|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
7aa37568
...
...
@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_
4
x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
ph
.
get_random_
8
x16
((
tmp
+
i
*
8
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
}
block_sync_lds
();
...
...
@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_
4
x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
ph
.
get_random_
8
x16
((
tmp
+
i
*
8
),
element_global_1d_id
+
i
*
Offset
{}
*
MRaw
);
}
block_sync_lds
();
...
...
@@ -218,21 +218,19 @@ struct BlockwiseDropout
}
// get raw z matrix with random number for shuffle
template
<
typename
ZThreadBuffer
,
typename
Step
,
typename
Offset
>
// N3*N4=8
template
<
typename
ZThreadBuffer
,
typename
Step
,
typename
Offset
>
__host__
__device__
void
GenerateZMatrixAttnFwd
(
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
/
Step
{}.
value
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_
4
x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
Offset
{});
ph
.
get_random_
8
x16
((
tmp
+
i
*
8
),
element_global_1d_id
+
i
*
Offset
{});
}
static_for
<
0
,
tmp_size
,
1
>
{}([
&
](
auto
i
)
{
z_thread_buf
(
i
)
=
tmp
[
i
.
value
];
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
7aa37568
...
...
@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
typename
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
,
typename
LSEGridDescriptor_M
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
...
...
@@ -73,8 +73,8 @@ __global__ void
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
...
...
@@ -141,7 +141,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
lse_grid_desc_m
,
block_2_ctile_map
,
c0_matrix_mask
,
...
...
@@ -174,7 +174,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
lse_grid_desc_m
,
block_2_ctile_map
,
c0_matrix_mask
,
...
...
@@ -203,7 +203,7 @@ __global__ void
ignore
=
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
ignore
=
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
;
ignore
=
lse_grid_desc_m
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
...
...
@@ -263,6 +263,7 @@ template <index_t NumDimG,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
DropoutStep
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -564,6 +565,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
DropoutStep
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -735,8 +737,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
(
z_grid_desc_m_n_
);
m_raw_padded_
=
GridwiseGemm
::
GetPaddedSize
(
raw_lengths_mz_nz_kz_gemm1nz_
[
0
]);
n_raw_padded_
=
GridwiseGemm
::
GetPaddedSize
(
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
...
...
@@ -791,8 +794,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -876,7 +879,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
,
DeviceOp
::
LSEGridDesc_M
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
...
...
@@ -909,7 +912,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg
.
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
,
arg
.
lse_grid_desc_m_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
7aa37568
...
...
@@ -135,7 +135,7 @@ __global__ void
arg_ptr
[
group_id
].
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
...
@@ -173,7 +173,7 @@ __global__ void
arg_ptr
[
group_id
].
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
...
@@ -244,6 +244,7 @@ template <index_t NumDimG,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
DropoutStep
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -566,6 +567,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
DropoutStep
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -622,8 +624,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
...
@@ -768,12 +770,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
// typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
(
z_grid_desc_m_n
);
const
index_t
BlockStart
=
grid_size_
;
...
...
@@ -829,7 +827,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
z_grid_desc_m_n
,
lse_grid_desc_m
,
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
7aa37568
...
...
@@ -1533,8 +1533,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
...
...
@@ -1966,16 +1966,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_ten
s
or_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
s_slash_p_thread_buf
,
ph
,
z_ten
s
or_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
7aa37568
...
...
@@ -1473,8 +1473,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
...
...
@@ -1865,16 +1865,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_ten
s
or_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
s_slash_p_thread_buf
,
ph
,
z_ten
s
or_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
7aa37568
...
...
@@ -110,6 +110,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// C desc for source in blockwise copy
...
...
@@ -119,10 +124,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -136,9 +140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -542,9 +544,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -646,8 +646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
static
constexpr
index_t
GemmMWave
=
Gemm0NWaves
;
// 4 // 4
static
constexpr
index_t
GemmNWave
=
Gemm0MWaves
;
// 1 // 1
...
...
@@ -770,9 +769,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
// 1 // 1
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
// 1 // 1
static
constexpr
index_t
GemmKLoop
=
Gemm2_K
/
Sum_K
;
// 2 // 2
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
B_K2
=
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
GemmKPack
,
false
>
{}.
K0PerXdlops
;
// 2
static
constexpr
index_t
B_K1
=
Sum_K
/
B_K2
/
B_K3
;
// 4
...
...
@@ -1570,8 +1568,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1759,7 +1757,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1774,23 +1771,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1806,15 +1807,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
7aa37568
...
...
@@ -121,6 +121,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
...
...
@@ -133,10 +138,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -150,9 +154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -522,9 +524,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -657,8 +657,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
@@ -709,9 +708,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
index_t
GemmMWave
=
BlockSize
/
get_warp_size
()
/
GemmNWave
;
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
mfma
.
k_per_blk
);
using
BBlockSliceLengths
=
Sequence
<
B_K0
,
Gemm2_N
,
B_K1
>
;
using
BThreadClusterLengths
=
...
...
@@ -1554,8 +1551,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1722,7 +1719,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1737,23 +1733,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1769,14 +1769,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
7aa37568
...
...
@@ -109,6 +109,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
constexpr
auto
Gemm0MWaves
=
MPerBlock
/
(
MPerXdl
*
MXdlPerWave
);
static
constexpr
auto
Gemm0NWaves
=
NPerBlock
/
(
NPerXdl
*
NXdlPerWave
);
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// C desc for source in blockwise copy
...
...
@@ -118,10 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -135,9 +139,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -563,9 +565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -667,8 +667,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
static
constexpr
index_t
GemmMWave
=
Gemm0NWaves
;
// 4 // 4
static
constexpr
index_t
GemmNWave
=
Gemm0MWaves
;
// 1 // 1
...
...
@@ -791,9 +790,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
// 1 // 1
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
// 1 // 1
static
constexpr
index_t
GemmKLoop
=
Gemm2_K
/
Sum_K
;
// 2 // 2
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
GemmKPack
=
math
::
max
(
A_K1
,
mfma
.
k_per_blk
);
static
constexpr
index_t
B_K3
=
GemmKPack
;
// 8
static
constexpr
index_t
B_K2
=
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
GemmKPack
,
false
>
{}.
K0PerXdlops
;
// 2
static
constexpr
index_t
B_K1
=
Sum_K
/
B_K2
/
B_K3
;
// 4
...
...
@@ -1621,8 +1619,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1946,7 +1944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1961,23 +1958,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1993,15 +1994,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
7aa37568
...
...
@@ -120,6 +120,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
...
...
@@ -132,10 +137,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
...
...
@@ -149,9 +153,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -543,9 +545,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
...
...
@@ -678,8 +678,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
static
constexpr
index_t
GemmKPack
=
mfma
.
group_size
;
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
@@ -730,9 +729,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
index_t
GemmMWave
=
BlockSize
/
get_warp_size
()
/
GemmNWave
;
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
static
constexpr
index_t
GemmMRepeat
=
Gemm2_M
/
GemmMWave
/
MPerXdl
;
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
GemmKPack
=
math
::
max
(
math
::
lcm
(
A_K1
,
B_K1
),
mfma
.
k_per_blk
);
using
BBlockSliceLengths
=
Sequence
<
B_K0
,
Gemm2_N
,
B_K1
>
;
using
BThreadClusterLengths
=
...
...
@@ -1582,8 +1579,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ushort
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
...
...
@@ -1862,7 +1859,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
)
{
...
...
@@ -1877,23 +1873,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tensor_buffer
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tensor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
}
...
...
@@ -1909,14 +1909,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_tile_id
=
z_random_matrix_offset
+
(
m_global
/
DropoutTile
)
*
DropoutTile
*
raw_n_padded
+
(
n_global
/
DropoutTile
)
*
DropoutTile
;
auto
global_elem_id
=
global_tile_id
+
(
wave_m_n_id
[
I0
]
*
M4
)
+
(
n_global
%
DropoutTile
)
*
raw_n_padded
;
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
DropoutTile
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp
View file @
7aa37568
...
...
@@ -873,8 +873,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
;
z_ten
s
or_buffer
.
Clear
();
// z matrix global desc
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -1022,16 +1022,16 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_ten
s
or_buffer
),
false
,
decltype
(
n0
),
decltype
(
i
)>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
acc_thread_buf
,
ph
,
z_ten
s
or_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
7aa37568
...
...
@@ -60,6 +60,7 @@ template <typename FloatAB,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
DropoutStepValue
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -113,6 +114,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
WaveSize
=
64
;
...
...
@@ -130,54 +133,76 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
// 16
static
constexpr
auto
DropoutMThread
=
DropoutTile
;
// 16
static
constexpr
auto
DropoutTilePerXdl
=
NPerXdl
/
DropoutTile
;
// 2
static
constexpr
auto
DropoutStep
=
Number
<
DropoutStepValue
>
{};
// 1 2 4
static
constexpr
auto
DropoutNRepeat
=
Number
<
math
::
integer_divide_ceil
(
DropoutStep
,
DropoutTilePerXdl
)
>
{};
// 1 1 2
static
constexpr
auto
DropoutGroupPerTile
=
Number
<
mfma
.
num_groups_per_blk
/
DropoutTilePerXdl
>
{};
// 2
static
constexpr
auto
DropoutStepPerXdl
=
Number
<
math
::
min
(
DropoutStep
,
DropoutTilePerXdl
)
>
{};
// 1 2 2
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
// C desc for source in gridwise copy
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
////=> for z use
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
const
auto
M0
=
M
/
MPerBlock
;
const
auto
N0
=
N
/
(
DropoutNRepeat
*
NPerXdl
);
constexpr
auto
M1
=
MXdlPerWave
;
constexpr
auto
N1
=
DropoutNRepeat
;
constexpr
auto
M2
=
Gemm0MWaves
;
constexpr
auto
N2
=
Gemm0NWaves
;
constexpr
auto
M3
=
DropoutTilePerXdl
;
constexpr
auto
N3
=
DropoutStepPerXdl
;
constexpr
auto
M4
=
DropoutTile
;
constexpr
auto
N4
=
DropoutGroupPerTile
;
constexpr
auto
N5
=
mfma
.
num_input_blks
;
constexpr
auto
N6
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
,
N5
,
N6
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
8
>
{},
Sequence
<
1
,
3
,
5
,
7
,
9
,
10
,
11
>
{}));
}
__host__
__device__
static
constexpr
auto
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()
__host__
__device__
static
constexpr
auto
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
()
{
constexpr
auto
mfma
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M0
=
MXdlPerWave
;
constexpr
auto
M1
=
Gemm0MWaves
;
constexpr
auto
N1
=
Gemm0NWaves
;
constexpr
auto
M2
=
MPerXdl
;
constexpr
auto
N2
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N3
=
mfma
.
num_input_blks
;
constexpr
auto
N4
=
mfma
.
group_size
;
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M0
,
I1
,
M1
,
N1
,
M2
,
N2
,
N3
,
N4
));
return
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
;
constexpr
auto
M0
=
MXdlPerWave
;
constexpr
auto
N0
=
DropoutNRepeat
;
constexpr
auto
M1
=
Gemm0MWaves
;
constexpr
auto
N1
=
Gemm0NWaves
;
constexpr
auto
M2
=
DropoutTilePerXdl
;
constexpr
auto
N2
=
DropoutStepPerXdl
;
constexpr
auto
M3
=
DropoutTile
;
constexpr
auto
N3
=
DropoutGroupPerTile
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M0
,
N0
,
M1
,
N1
,
M2
,
N2
,
M3
,
N3
,
N4
,
N5
));
return
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
}
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
constexpr
auto
mfma
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
return
math
::
integer_divide_ceil
(
size
,
DropoutTile
)
*
DropoutTile
;
}
__device__
static
auto
GetGemm0WaveIdx
()
...
...
@@ -434,10 +459,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
...
...
@@ -468,8 +492,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
(
ZGridDesc_M_N
{}))
>
;
struct
SharedMemTrait
{
...
...
@@ -507,10 +531,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
// LDS allocation for Z shuffle in LDS
static
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
static
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
m3_
n3_n4
_n5
=
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_
M3_
N3_N4
_N5
();
static
constexpr
auto
z_shuffle_block_space_size
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
();
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
m3_
n3_n4
_n5
.
GetElementSpaceSize
();
};
template
<
bool
HasMainKBlockLoop
,
...
...
@@ -538,8 +562,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
...
...
@@ -661,9 +685,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
@@ -823,8 +845,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr
index_t
Gemm1KPack
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
constexpr
index_t
Gemm1KPack
=
mfma
.
group_size
;
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
...
...
@@ -1008,67 +1029,75 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
},
Number
<
NumD0Tensor
>
{});
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
// MRepeat
DropoutNRepeat
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
I1
,
DropoutStepPerXdl
,
m2
,
DropoutGroupPerTile
,
n3
,
n4
));
// RegisterNum
constexpr
auto
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
// MRepeat
DropoutNRepeat
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
I1
,
DropoutStepPerXdl
,
DropoutGroupPerTile
,
n3
,
n4
,
// RegisterNum
m2
));
// z is random number matrix for dropout verify
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
// MRepeat
I1
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
));
// RegisterNum
constexpr
auto
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
// MRepeat
I1
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
,
// RegisterNum
I1
));
// I1
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockId
m0
,
// MRepeat
I1
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockId
m0
,
// MRepeat
DropoutNRepeat
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
I1
,
DropoutStepPerXdl
,
m2
,
DropoutGroupPerTile
,
n3
,
n4
));
// RegisterNum
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
ZM0
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
ZN0
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
ZM1
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
ZN1
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
ZM2
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
ZN2
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
ZN3
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
ZN4
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
=
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
();
constexpr
auto
ZM0
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I0
);
// 1
constexpr
auto
ZN0
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I1
);
// 1 1 2
constexpr
auto
ZM1
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I2
);
// 4
constexpr
auto
ZN1
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I3
);
// 1
constexpr
auto
ZM2
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I4
);
// 2
constexpr
auto
ZN2
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I5
);
// 1 2 2
constexpr
auto
ZM3
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I6
);
// 16
constexpr
auto
ZN3
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I7
);
// 2
constexpr
auto
ZN4
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I8
);
// 2
constexpr
auto
ZN5
=
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I9
);
// 4
constexpr
auto
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
transform_tensor_descriptor
(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
m3_
n3_n4
_n5
,
make_tuple
(
make_pass_through_transform
(
ZM0
),
make_pass_through_transform
(
ZN0
),
make_pass_through_transform
(
ZM1
),
make_pass_through_transform
(
ZN1
),
make_
unmerge_transform
(
make_tuple
(
ZM2
/
ZN4
,
ZN4
)
),
make_
pass_through_transform
(
ZM2
),
make_pass_through_transform
(
ZN2
),
make_
pass_through_transform
(
ZN3
),
make_
pass_through_transform
(
ZN4
)),
make_
unmerge_transform
(
make_tuple
(
ZM3
/
ZN4
/
ZN5
,
ZN4
,
ZN5
)
),
make_
merge_transform_v3_division_mod
(
make_tuple
(
ZN3
,
ZN4
,
ZN5
)
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -1076,112 +1105,130 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
Sequence
<
7
,
8
,
9
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
,
7
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
8
>
{}));
Sequence
<
6
,
7
,
8
>
{},
Sequence
<
9
>
{}));
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ushort
,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_
n
3_m
3
_n
4
.
GetElementSpaceSize
(),
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_
m
3_m
4_m5
_n
3
.
GetElementSpaceSize
(),
true
>
z_tensor_buffer
;
z_tensor_buffer
.
Clear
();
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
.
GetElementSpaceSize
());
auto
z_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ushort
*>
(
p_shared
),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
m3_
n3_n4
_n5
.
GetElementSpaceSize
());
auto
z_tmp_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_
m3_
n3_n4
_n5
),
decltype
(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
m3_
n3_n4
_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
m0
,
// MRepeat
I1
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
Sequence
<
m0
,
// MRepeat
DropoutNRepeat
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
I1
,
DropoutStepPerXdl
,
m2
,
DropoutGroupPerTile
,
n3
,
n4
>
,
// RegisterNum
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
// DstVectorDim
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_multi_index
(
0
,
// MRepeat
0
,
// NRepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// NGroupIndex
wave_m_n_id
[
I0
],
// NInputIndex
true
>
{
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
// MRepeat
0
,
// NRepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
]
/
DropoutMThread
,
0
,
wave_m_n_id
[
I1
]
%
DropoutMThread
,
0
,
wave_m_n_id
[
I0
],
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
z_shuffle_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
ushort
,
ushort
,
decltype
(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
),
decltype
(
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
),
Sequence
<
m0
,
I1
,
m1
,
n1
,
m2
,
n2
,
n3
,
n4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
>
,
8
,
decltype
(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
Sequence
<
m0
,
DropoutNRepeat
,
m1
,
n1
,
I1
,
DropoutStepPerXdl
,
DropoutGroupPerTile
,
n3
,
n4
,
m2
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
1
,
1
,
true
>
{
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
n
3_m
3
_n
4
,
true
>
{
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
m
3_m
4_m5
_n
3
,
make_multi_index
(
0
,
// MRepeat
0
,
// NRepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
]
/
ZN4
,
wave_m_n_id
[
I1
]
/
DropoutMThread
,
0
,
0
,
wave_m_n_id
[
I0
],
0
,
wave_m_n_id
[
I1
]
%
ZN4
)};
wave_m_n_id
[
I1
]
%
DropoutMThread
)};
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
I1
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
DropoutNRepeat
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
I1
,
DropoutStepPerXdl
,
m2
,
DropoutGroupPerTile
,
n3
,
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
1
,
// DstScalarPerVector
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
>
,
11
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
make_multi_index
(
block_work_idx_m
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
wave_m_n_id
[
I1
]
/
DropoutMThread
,
0
,
wave_m_n_id
[
I1
]
%
DropoutMThread
,
0
,
wave_m_n_id
[
I0
],
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
...
...
@@ -1308,8 +1355,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
blockwise_softmax
.
Run
(
acc_thread_buf
,
workspace_buf
);
constexpr
auto
position
_offset
=
N
3
*
N4
;
constexpr
auto
iterator_
offset
=
n2
*
n3
*
n4
;
constexpr
auto
iterator
_offset
=
N
umber
<
8
*
DropoutStep
>
{}
;
constexpr
auto
iterator_
step
=
Number
<
n0
*
n1
*
n2
*
n3
*
n4
/
8
/
DropoutStep
>
{}
;
if
constexpr
(
IsDropout
)
// dropout
{
...
...
@@ -1326,49 +1373,44 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
n_global
;
// unique element global 1d id
blockwise_dropout
.
template
GenerateZMatrixAttnFwd
<
decltype
(
z_tensor_buffer
),
decltype
(
n0
),
decltype
(
position_offset
)>(
decltype
(
iterator_step
),
decltype
(
DropoutTile
)>(
ph
,
global_elem_id
,
z_tensor_buffer
);
z_tmp_thread_copy_vgpr_to_lds
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tensor_buffer
,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_buf
);
z_tmp_thread_copy_vgpr_to_lds
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tensor_buffer
,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_block_buf
);
z_shuffle_thread_copy_lds_to_vgpr
.
Run
(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
n
3_m
3
_n
4
,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_
m
3_m
4_m5
_n
3
,
z_block_buf
,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_
n
3_m
3
_n
4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_
m
3_m
4_m5
_n
3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tensor_buffer
);
blockwise_dropout
.
template
ApplyDropoutWithZ
<
decltype
(
acc_thread_buf
),
decltype
(
z_tensor_buffer
),
decltype
(
n0
),
decltype
(
iterator_step
),
decltype
(
i
)>(
acc_thread_buf
,
z_tensor_buffer
);
// save z to global
if
(
p_z_grid
)
if
(
p_z_grid
&&
(
gemm1_n_block_data_idx_on_grid
==
0
)
)
{
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tensor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
-
(
n0
.
value
),
0
,
0
,
0
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
// TODO: may convert to log domain
...
...
@@ -1489,7 +1531,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static_for
<
0
,
MXdlPerWave
,
1
>
{}(
[
&
](
auto
I
)
{
lse_thread_buf
(
I
)
=
running_max
(
I
)
+
math
::
log
(
running_sum
(
I
));
});
if
(
get_lane_local_1d_id
()
<
AccM2
)
if
(
(
get_lane_local_1d_id
()
<
AccM2
)
&&
(
gemm1_n_block_data_idx_on_grid
==
0
))
{
static_for
<
0
,
MXdlPerWave
,
1
>
{}([
&
](
auto
I
)
{
// copy from VGPR to Global
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment