Commit 7aa37568 authored by danyao12's avatar danyao12
Browse files

qloop dropout optimize

parent 4274096b
...@@ -121,6 +121,7 @@ using DeviceGemmInstance = ...@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -194,6 +195,7 @@ using DeviceGemmInstance = ...@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -257,7 +259,7 @@ using DeviceGemmInstance = ...@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -266,7 +268,8 @@ using DeviceGemmInstance = ...@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -282,7 +285,7 @@ using DeviceGemmInstance = ...@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8, 8,
true, true,
4, 4,
S<8, 32, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
......
...@@ -113,11 +113,11 @@ static constexpr bool Deterministic = false; ...@@ -113,11 +113,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 4, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 4, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -129,11 +129,11 @@ using DeviceGemmInstanceBWD = ...@@ -129,11 +129,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64) #elif(DIM <= 64)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -152,11 +152,11 @@ using DeviceGemmInstanceBWD = ...@@ -152,11 +152,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128) #elif(DIM <= 128)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
......
...@@ -121,6 +121,7 @@ using DeviceGemmInstance = ...@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -194,6 +195,7 @@ using DeviceGemmInstance = ...@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -257,7 +259,7 @@ using DeviceGemmInstance = ...@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -266,7 +268,8 @@ using DeviceGemmInstance = ...@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -282,7 +285,7 @@ using DeviceGemmInstance = ...@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8, 8,
true, true,
1, 1,
S<8, 32, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
......
...@@ -112,11 +112,11 @@ static constexpr bool Deterministic = false; ...@@ -112,11 +112,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 1, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 1, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -128,11 +128,11 @@ using DeviceGemmInstanceBWD = ...@@ -128,11 +128,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64) #elif(DIM <= 64)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -151,11 +151,11 @@ using DeviceGemmInstanceBWD = ...@@ -151,11 +151,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128) #elif(DIM <= 128)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
......
...@@ -138,12 +138,12 @@ struct BlockwiseDropout ...@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
} }
block_sync_lds(); block_sync_lds();
...@@ -179,12 +179,12 @@ struct BlockwiseDropout ...@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
} }
block_sync_lds(); block_sync_lds();
...@@ -218,21 +218,19 @@ struct BlockwiseDropout ...@@ -218,21 +218,19 @@ struct BlockwiseDropout
} }
// get raw z matrix with random number for shuffle // get raw z matrix with random number for shuffle
template <typename ZThreadBuffer, template <typename ZThreadBuffer, typename Step, typename Offset>
typename Step,
typename Offset> // N3*N4=8
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph, __host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf) ZThreadBuffer& z_thread_buf)
{ {
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value; constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{}); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
} }
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; }); static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
......
...@@ -40,7 +40,7 @@ template <typename GridwiseGemm, ...@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
...@@ -73,8 +73,8 @@ __global__ void ...@@ -73,8 +73,8 @@ __global__ void
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
...@@ -141,7 +141,7 @@ __global__ void ...@@ -141,7 +141,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -174,7 +174,7 @@ __global__ void ...@@ -174,7 +174,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -203,7 +203,7 @@ __global__ void ...@@ -203,7 +203,7 @@ __global__ void
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
...@@ -263,6 +263,7 @@ template <index_t NumDimG, ...@@ -263,6 +263,7 @@ template <index_t NumDimG,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t DropoutStep,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -564,6 +565,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -564,6 +565,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
Gemm1NXdlPerWave, Gemm1NXdlPerWave,
DropoutStep,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -735,8 +737,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -735,8 +737,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
z_grid_desc_m_n_);
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]); m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]); n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
...@@ -791,8 +794,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -791,8 +794,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -876,7 +879,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -876,7 +879,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
...@@ -909,7 +912,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -909,7 +912,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
......
...@@ -135,7 +135,7 @@ __global__ void ...@@ -135,7 +135,7 @@ __global__ void
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -173,7 +173,7 @@ __global__ void ...@@ -173,7 +173,7 @@ __global__ void
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -244,6 +244,7 @@ template <index_t NumDimG, ...@@ -244,6 +244,7 @@ template <index_t NumDimG,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t DropoutStep,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -566,6 +567,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -566,6 +567,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
Gemm1NXdlPerWave, Gemm1NXdlPerWave,
DropoutStep,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -622,8 +624,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -622,8 +624,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
...@@ -768,12 +770,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -768,12 +770,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n); c_grid_desc_m_n);
// typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
z_grid_desc_m_n); z_grid_desc_m_n);
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
...@@ -829,7 +827,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -829,7 +827,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m_n, z_grid_desc_m_n,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n), block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
......
...@@ -1533,8 +1533,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -1533,8 +1533,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
unsigned short, unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tensor_buffer;
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
// z matrix global desc // z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1); /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
...@@ -1966,16 +1966,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -1966,16 +1966,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
// P_dropped // P_dropped
static_for<0, n0, 1>{}([&](auto i) { static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
true, true,
decltype(n0), decltype(n0),
decltype(i)>( decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer); s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
...@@ -1473,8 +1473,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -1473,8 +1473,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
unsigned short, unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tensor_buffer;
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
// z matrix global desc // z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1); /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
...@@ -1865,16 +1865,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -1865,16 +1865,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
// P_dropped // P_dropped
static_for<0, n0, 1>{}([&](auto i) { static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
true, true,
decltype(n0), decltype(n0),
decltype(i)>( decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer); s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
...@@ -110,6 +110,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -110,6 +110,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy // C desc for source in blockwise copy
...@@ -119,10 +124,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -119,10 +124,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M3 = mfma.num_groups_per_blk; constexpr auto M4 = mfma.num_input_blks;
constexpr auto M4 = mfma.num_input_blks; constexpr auto M5 = mfma.group_size;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor( return transform_tensor_descriptor(
z_grid_desc_m_n, z_grid_desc_m_n,
...@@ -136,9 +140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -136,9 +140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size) __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{ {
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
} }
__device__ static auto GetGemm0WaveIdx() __device__ static auto GetGemm0WaveIdx()
...@@ -542,9 +544,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -542,9 +544,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
BBlockDesc_BK0_N_BK1{}); BBlockDesc_BK0_N_BK1{});
} }
static constexpr index_t KPack = static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output // Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
...@@ -646,8 +646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -646,8 +646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = mfma.group_size;
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4 static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1 static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
...@@ -770,9 +769,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -770,9 +769,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1 static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1 static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2 static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
math::max(A_K1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 = static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2 XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4 static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
...@@ -1570,8 +1568,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1570,8 +1568,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ushort, ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tensor_buffer;
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
...@@ -1759,7 +1757,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1759,7 +1757,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout) if constexpr(IsDropout)
{ {
...@@ -1774,23 +1771,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1774,23 +1771,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (n_global % DropoutTile) * raw_n_padded;
blockwise_dropout blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
decltype(position_offset), decltype(DropoutTile),
true>( true>(s_slash_p_thread_buf,
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf); z_grid_buf);
} }
...@@ -1806,15 +1807,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1806,15 +1807,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (n_global % DropoutTile) * raw_n_padded;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(DropoutTile),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
......
...@@ -121,6 +121,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -121,6 +121,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{}; static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
...@@ -133,10 +138,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -133,10 +138,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M3 = mfma.num_groups_per_blk; constexpr auto M4 = mfma.num_input_blks;
constexpr auto M4 = mfma.num_input_blks; constexpr auto M5 = mfma.group_size;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor( return transform_tensor_descriptor(
z_grid_desc_m_n, z_grid_desc_m_n,
...@@ -150,9 +154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -150,9 +154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size) __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{ {
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
} }
__device__ static auto GetGemm0WaveIdx() __device__ static auto GetGemm0WaveIdx()
...@@ -522,9 +524,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -522,9 +524,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true, // DstResetCoord true, // DstResetCoord
NumGemmKPrefetchStage>; NumGemmKPrefetchStage>;
static constexpr index_t KPack = static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output // Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
...@@ -657,8 +657,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -657,8 +657,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = mfma.group_size;
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
...@@ -709,9 +708,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -709,9 +708,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave; static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);
math::max(math::lcm(A_K1, B_K1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>; using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
using BThreadClusterLengths = using BThreadClusterLengths =
...@@ -1554,8 +1551,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1554,8 +1551,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ushort, ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tensor_buffer;
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
...@@ -1722,7 +1719,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1722,7 +1719,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout) if constexpr(IsDropout)
{ {
...@@ -1737,23 +1733,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1737,23 +1733,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (n_global % DropoutTile) * raw_n_padded;
blockwise_dropout blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
decltype(position_offset), decltype(DropoutTile),
true>( true>(s_slash_p_thread_buf,
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf); z_grid_buf);
} }
...@@ -1769,14 +1769,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1769,14 +1769,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(DropoutTile),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
......
...@@ -109,6 +109,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -109,6 +109,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy // C desc for source in blockwise copy
...@@ -118,10 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -118,10 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M3 = mfma.num_groups_per_blk; constexpr auto M4 = mfma.num_input_blks;
constexpr auto M4 = mfma.num_input_blks; constexpr auto M5 = mfma.group_size;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor( return transform_tensor_descriptor(
z_grid_desc_m_n, z_grid_desc_m_n,
...@@ -135,9 +139,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -135,9 +139,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size) __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{ {
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
} }
__device__ static auto GetGemm0WaveIdx() __device__ static auto GetGemm0WaveIdx()
...@@ -563,9 +565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -563,9 +565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockDesc_BK0_N_BK1{}); BBlockDesc_BK0_N_BK1{});
} }
static constexpr index_t KPack = static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output // Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
...@@ -667,8 +667,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -667,8 +667,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = mfma.group_size;
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4 static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1 static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
...@@ -791,9 +790,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -791,9 +790,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1 static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1 static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2 static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
math::max(A_K1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 = static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2 XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4 static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
...@@ -1621,8 +1619,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1621,8 +1619,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ushort, ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tensor_buffer;
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
...@@ -1946,7 +1944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1946,7 +1944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout) if constexpr(IsDropout)
{ {
...@@ -1961,23 +1958,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1961,23 +1958,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (n_global % DropoutTile) * raw_n_padded;
blockwise_dropout blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
decltype(position_offset), decltype(DropoutTile),
true>( true>(s_slash_p_thread_buf,
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf); z_grid_buf);
} }
...@@ -1993,15 +1994,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1993,15 +1994,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (n_global % DropoutTile) * raw_n_padded;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(DropoutTile),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
......
...@@ -120,6 +120,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -120,6 +120,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{}; static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
...@@ -132,10 +137,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -132,10 +137,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M3 = mfma.num_groups_per_blk; constexpr auto M4 = mfma.num_input_blks;
constexpr auto M4 = mfma.num_input_blks; constexpr auto M5 = mfma.group_size;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor( return transform_tensor_descriptor(
z_grid_desc_m_n, z_grid_desc_m_n,
...@@ -149,9 +153,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -149,9 +153,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size) __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{ {
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
} }
__device__ static auto GetGemm0WaveIdx() __device__ static auto GetGemm0WaveIdx()
...@@ -543,9 +545,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -543,9 +545,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord true, // DstResetCoord
NumGemmKPrefetchStage>; NumGemmKPrefetchStage>;
static constexpr index_t KPack = static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output // Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
...@@ -678,8 +678,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -678,8 +678,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = mfma.group_size;
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
...@@ -730,9 +729,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -730,9 +729,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave; static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);
math::max(math::lcm(A_K1, B_K1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>; using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
using BThreadClusterLengths = using BThreadClusterLengths =
...@@ -1582,8 +1579,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1582,8 +1579,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ushort, ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tensor_buffer;
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
...@@ -1862,7 +1859,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1862,7 +1859,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global // save z to global
if constexpr(IsDropout) if constexpr(IsDropout)
{ {
...@@ -1877,23 +1873,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1877,23 +1873,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4; (n_global % DropoutTile) * raw_n_padded;
blockwise_dropout blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf), .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
decltype(position_offset), decltype(DropoutTile),
true>( true>(s_slash_p_thread_buf,
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded); ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf); z_grid_buf);
} }
...@@ -1909,14 +1909,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1909,14 +1909,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded + auto global_tile_id = z_random_matrix_offset +
n_global; // unique element global 1d id (m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset), decltype(DropoutTile),
true>( true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded); s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
} }
......
...@@ -873,8 +873,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -873,8 +873,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
unsigned short, unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tensor_buffer;
z_tenor_buffer.Clear(); z_tensor_buffer.Clear();
// z matrix global desc // z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1022,16 +1022,16 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1022,16 +1022,16 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{ {
static_for<0, n0, 1>{}([&](auto i) { static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer), decltype(z_tensor_buffer),
false, false,
decltype(n0), decltype(n0),
decltype(i)>( decltype(i)>(
acc_thread_buf, ph, z_tenor_buffer); acc_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
...@@ -60,6 +60,7 @@ template <typename FloatAB, ...@@ -60,6 +60,7 @@ template <typename FloatAB,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t DropoutStepValue,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -113,6 +114,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -113,6 +114,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64; static constexpr auto WaveSize = 64;
...@@ -130,54 +133,76 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -130,54 +133,76 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{}; static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
static constexpr auto DropoutMThread = DropoutTile; // 16
static constexpr auto DropoutTilePerXdl = NPerXdl / DropoutTile; // 2
static constexpr auto DropoutStep = Number<DropoutStepValue>{}; // 1 2 4
static constexpr auto DropoutNRepeat =
Number<math::integer_divide_ceil(DropoutStep, DropoutTilePerXdl)>{}; // 1 1 2
static constexpr auto DropoutGroupPerTile =
Number<mfma.num_groups_per_blk / DropoutTilePerXdl>{}; // 2
static constexpr auto DropoutStepPerXdl =
Number<math::min(DropoutStep, DropoutTilePerXdl)>{}; // 1 2 2
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in gridwise copy // C desc for source in gridwise copy
__host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( __host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{ {
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma; const auto M0 = M / MPerBlock;
constexpr auto N3 = mfma.num_groups_per_blk; const auto N0 = N / (DropoutNRepeat * NPerXdl);
constexpr auto N4 = mfma.num_input_blks; constexpr auto M1 = MXdlPerWave;
constexpr auto N5 = mfma.group_size; constexpr auto N1 = DropoutNRepeat;
constexpr auto M2 = Gemm0MWaves;
constexpr auto N2 = Gemm0NWaves;
constexpr auto M3 = DropoutTilePerXdl;
constexpr auto N3 = DropoutStepPerXdl;
constexpr auto M4 = DropoutTile;
constexpr auto N4 = DropoutGroupPerTile;
constexpr auto N5 = mfma.num_input_blks;
constexpr auto N6 = mfma.group_size;
return transform_tensor_descriptor( return transform_tensor_descriptor(
z_grid_desc_m_n, z_grid_desc_m_n,
make_tuple(make_unmerge_transform( make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)), make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5, N6))),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); make_tuple(Sequence<0, 2, 4, 6, 8>{}, Sequence<1, 3, 5, 7, 9, 10, 11>{}));
} }
__host__ __device__ static constexpr auto GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() __host__ __device__ static constexpr auto
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5()
{ {
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma; constexpr auto M0 = MXdlPerWave;
constexpr auto M0 = MXdlPerWave; constexpr auto N0 = DropoutNRepeat;
constexpr auto M1 = Gemm0MWaves; constexpr auto M1 = Gemm0MWaves;
constexpr auto N1 = Gemm0NWaves; constexpr auto N1 = Gemm0NWaves;
constexpr auto M2 = MPerXdl; constexpr auto M2 = DropoutTilePerXdl;
constexpr auto N2 = mfma.num_groups_per_blk; constexpr auto N2 = DropoutStepPerXdl;
constexpr auto N3 = mfma.num_input_blks; constexpr auto M3 = DropoutTile;
constexpr auto N4 = mfma.group_size; constexpr auto N3 = DropoutGroupPerTile;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto N5 = mfma.group_size;
make_naive_tensor_descriptor_packed(make_tuple(M0, I1, M1, N1, M2, N2, N3, N4));
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4; make_naive_tensor_descriptor_packed(make_tuple(M0, N0, M1, N1, M2, N2, M3, N3, N4, N5));
return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
} }
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size) __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{ {
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma; return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
} }
__device__ static auto GetGemm0WaveIdx() __device__ static auto GetGemm0WaveIdx()
...@@ -434,10 +459,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -434,10 +459,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto M = d0_grid_desc_m_n.GetLength(I0); const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1); const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma; constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N4 = mfma.num_input_blks;
constexpr auto N4 = mfma.num_input_blks; constexpr auto N5 = mfma.group_size;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor( return transform_tensor_descriptor(
d0_grid_desc_m_n, d0_grid_desc_m_n,
make_tuple(make_unmerge_transform( make_tuple(make_unmerge_transform(
...@@ -468,8 +492,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -468,8 +492,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype( using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>; MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(ZGridDesc_M_N{}))>;
struct SharedMemTrait struct SharedMemTrait
{ {
...@@ -507,10 +531,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -507,10 +531,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
// LDS allocation for Z shuffle in LDS // LDS allocation for Z shuffle in LDS
static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
static constexpr auto z_shuffle_block_space_size = static constexpr auto z_shuffle_block_space_size =
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize(); z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize();
}; };
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
...@@ -538,8 +562,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -538,8 +562,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
...@@ -661,9 +685,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -661,9 +685,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// acc1[m][o] += acc[m][n] * B1[n][o] // acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
math::max(math::lcm(AK1, BK1),
MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2< auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
...@@ -823,8 +845,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -823,8 +845,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack = constexpr index_t Gemm1KPack = mfma.group_size;
MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
...@@ -1008,67 +1029,75 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1008,67 +1029,75 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
}, },
Number<NumD0Tensor>{}); Number<NumD0Tensor>{});
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
DropoutNRepeat, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n3,
n4)); // RegisterNum
constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
DropoutNRepeat, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
DropoutGroupPerTile,
n3,
n4, // RegisterNum
m2));
// z is random number matrix for dropout verify // z is random number matrix for dropout verify
// //
// z vgpr copy to global // z vgpr copy to global
// //
// z matrix threadwise desc // z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NRepeat I1, // NBlockId
m1, // MWaveId m0, // MRepeat
n1, // NWaveId DropoutNRepeat, // NRepeat
m2, // MPerXdl m1, // MWaveId
n2, // NGroupNum n1, // NWaveId
n3, // NInputNum I1,
n4)); // RegisterNum DropoutStepPerXdl,
m2,
constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy DropoutGroupPerTile,
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat n3,
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4, // RegisterNum
I1)); // I1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // RegisterNum n4)); // RegisterNum
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I0); // 1
constexpr auto ZN0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); constexpr auto ZN0 =
constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I1); // 1 1 2
constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I2); // 4
constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I3); // 1
constexpr auto ZN2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I4); // 2
constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); constexpr auto ZN2 =
constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I5); // 1 2 2
constexpr auto ZM3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I6); // 16
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I7); // 2
constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I8); // 2
constexpr auto ZN5 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I9); // 4
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
transform_tensor_descriptor( transform_tensor_descriptor(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(make_pass_through_transform(ZM0), make_tuple(make_pass_through_transform(ZM0),
make_pass_through_transform(ZN0), make_pass_through_transform(ZN0),
make_pass_through_transform(ZM1), make_pass_through_transform(ZM1),
make_pass_through_transform(ZN1), make_pass_through_transform(ZN1),
make_unmerge_transform(make_tuple(ZM2 / ZN4, ZN4)), make_pass_through_transform(ZM2),
make_pass_through_transform(ZN2), make_pass_through_transform(ZN2),
make_pass_through_transform(ZN3), make_unmerge_transform(make_tuple(ZM3 / ZN4 / ZN5, ZN4, ZN5)),
make_pass_through_transform(ZN4)), make_merge_transform_v3_division_mod(make_tuple(ZN3, ZN4, ZN5))),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -1076,112 +1105,130 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1076,112 +1105,130 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}, Sequence<5>{},
Sequence<6>{}, Sequence<6>{},
Sequence<7>{}), Sequence<7, 8, 9>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<3>{}, Sequence<3>{},
Sequence<4, 7>{}, Sequence<4>{},
Sequence<5>{}, Sequence<5>{},
Sequence<6>{}, Sequence<6, 7, 8>{},
Sequence<8>{})); Sequence<9>{}));
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
ushort, ushort,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(), z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true> true>
z_tensor_buffer; z_tensor_buffer;
z_tensor_buffer.Clear(); z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6.GetElementSpaceSize());
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ushort*>(p_shared), static_cast<ushort*>(p_shared),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()); z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ushort, ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat Sequence<m0, // MRepeat
I1, // NRepeat DropoutNRepeat, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl I1,
n2, // NGroupNum DropoutStepPerXdl,
n3, // NInputNum m2,
DropoutGroupPerTile,
n3,
n4>, // RegisterNum n4>, // RegisterNum
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
7, // DstVectorDim 9, // DstVectorDim
1, // DstScalarPerVector 1, // DstScalarPerVector
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, // MRepeat make_multi_index(0, // MRepeat
0, // NRepeat 0, // NRepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl wave_m_n_id[I1] / DropoutMThread,
0, // NGroupIndex 0,
wave_m_n_id[I0], // NInputIndex wave_m_n_id[I1] % DropoutMThread,
0,
wave_m_n_id[I0],
0), 0),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ushort, ushort,
ushort, ushort,
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4), decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4), decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
Sequence<m0, I1, m1, n1, m2, n2, n3, n4, I1>, Sequence<m0,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, DropoutNRepeat,
8, m1,
n1,
I1,
DropoutStepPerXdl,
DropoutGroupPerTile,
n3,
n4,
m2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
1, 1,
1, 1,
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4, true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MRepeat make_multi_index(0, // MRepeat
0, // NRepeat 0, // NRepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1] / ZN4, wave_m_n_id[I1] / DropoutMThread,
0,
0, 0,
wave_m_n_id[I0], wave_m_n_id[I0],
0, 0,
wave_m_n_id[I1] % ZN4)}; wave_m_n_id[I1] % DropoutMThread)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ZDataType, ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
I1, // NRepeat DropoutNRepeat, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl I1,
n2, // NGroupNum DropoutStepPerXdl,
n3, // NInputNum m2,
DropoutGroupPerTile,
n3,
n4>, n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>,
9, // DstVectorDim 11, // DstVectorDim
1, // DstScalarPerVector 1, // DstScalarPerVector
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_multi_index(block_work_idx_m, // MBlockId make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId 0, // NBlockId
0, // mrepeat 0, // mrepeat
0, // nrepeat 0, // nrepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl wave_m_n_id[I1] / DropoutMThread,
0, // group 0,
wave_m_n_id[I0], // NInputIndex wave_m_n_id[I1] % DropoutMThread,
0,
wave_m_n_id[I0],
0), 0),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
...@@ -1308,8 +1355,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1308,8 +1355,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
blockwise_softmax.Run(acc_thread_buf, workspace_buf); blockwise_softmax.Run(acc_thread_buf, workspace_buf);
constexpr auto position_offset = N3 * N4; constexpr auto iterator_offset = Number<8 * DropoutStep>{};
constexpr auto iterator_offset = n2 * n3 * n4; constexpr auto iterator_step = Number<n0 * n1 * n2 * n3 * n4 / 8 / DropoutStep>{};
if constexpr(IsDropout) // dropout if constexpr(IsDropout) // dropout
{ {
...@@ -1326,49 +1373,44 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1326,49 +1373,44 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
n_global; // unique element global 1d id n_global; // unique element global 1d id
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer), blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer),
decltype(n0), decltype(iterator_step),
decltype(position_offset)>( decltype(DropoutTile)>(
ph, global_elem_id, z_tensor_buffer); ph, global_elem_id, z_tensor_buffer);
z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, z_tmp_thread_copy_vgpr_to_lds.Run(
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_tensor_buffer, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, z_tensor_buffer,
z_block_buf); z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_block_buf);
z_shuffle_thread_copy_lds_to_vgpr.Run( z_shuffle_thread_copy_lds_to_vgpr.Run(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4, z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_block_buf, z_block_buf,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4, z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer); z_tensor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf), blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tensor_buffer), decltype(z_tensor_buffer),
decltype(n0), decltype(iterator_step),
decltype(i)>(acc_thread_buf, decltype(i)>(acc_thread_buf,
z_tensor_buffer); z_tensor_buffer);
// save z to global // save z to global
if(p_z_grid) if(p_z_grid && (gemm1_n_block_data_idx_on_grid == 0))
{ {
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer, z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} }
}); });
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} }
// TODO: may convert to log domain // TODO: may convert to log domain
...@@ -1489,7 +1531,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1489,7 +1531,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static_for<0, MXdlPerWave, 1>{}( static_for<0, MXdlPerWave, 1>{}(
[&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); }); [&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
if(get_lane_local_1d_id() < AccM2) if((get_lane_local_1d_id() < AccM2) && (gemm1_n_block_data_idx_on_grid == 0))
{ {
static_for<0, MXdlPerWave, 1>{}([&](auto I) { static_for<0, MXdlPerWave, 1>{}([&](auto I) {
// copy from VGPR to Global // copy from VGPR to Global
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment