Commit 35b2971e authored by danyao12

fix bugs and optimize the 2 bwd qloop kernels

parent 52478ac3
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
@@ -105,48 +105,48 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using DeviceGemmInstance =
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, 32, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
+        ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, 32, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 64)
// clang-format off
using DeviceGemmInstance =
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
+        ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 128)
// clang-format off
using DeviceGemmInstance =
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 4, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
-        ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 128, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
+        ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 128, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, 64, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, 64, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#endif
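For orientation, a minimal sketch of how such a DeviceGemmInstance is typically driven, an assumption based on the usual composable_kernel device-op interface rather than text from this commit; the concrete MakeArgument(...) parameters are the pointer/lengths/strides lists visible in the hunks below:

    // Hedged sketch of the standard CK device-op flow (not this file verbatim).
    auto op      = DeviceGemmInstance{};
    auto invoker = op.MakeInvoker();
    // auto argument = op.MakeArgument(/* q/k/v/y/lse/d pointers, the *_lengths
    //     and *_strides vectors shown below, element-wise ops, ... */);
    // if(!op.IsSupportedArgument(argument))
    //     return 1; // this tuning instance cannot handle the problem shape
    // float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});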
@@ -382,11 +382,6 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
-        // y dot ygrad sum
-        std::vector<ck::index_t> d_gs_ms_lengths{G0, G1, M};
-        std::vector<ck::index_t> d_gs_ms_strides{
-            G1 * M, M, 1}; // O dot Ograd row sum layout [G0, G1, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
@@ -394,7 +389,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
-    Tensor<DDataType> d_gs_ms(d_gs_ms_lengths, d_gs_ms_strides);
+    Tensor<DDataType> d_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);

std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
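The Tensor<DDataType> change above follows directly from that identity: D is per-row like LSE, so its tensor can be built from the LSE descriptors. A hedged equivalence sketch (the d_gs_ms_* names are the deleted variables, shown only for comparison):

    // Sketch (assumption): both constructions describe the same [G0, G1, M] layout,
    // so dropping the d_gs_ms_* vectors loses nothing.
    // Tensor<DDataType> d_gs_ms(d_gs_ms_lengths, d_gs_ms_strides);    // before: {G0, G1, M}, {G1 * M, M, 1}
    // Tensor<DDataType> d_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); // after: identical layout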
@@ -532,7 +527,6 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
-            d_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -579,7 +573,6 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
-            d_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
-#define DIM 32 // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
@@ -104,48 +104,48 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using DeviceGemmInstance =
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, 32, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
+        ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, 32, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 64)
// clang-format off
using DeviceGemmInstance =
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
+        ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################|        |        |        |        |        | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################|        |        |        |        |        | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################|        |        |        |        |        | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 128)
// clang-format off
using DeviceGemmInstance =
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################|        |        |        |        |        | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################|        |        |        |        |        | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################|        |        |        |        |        | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 4, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 128, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, 64, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
    // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, 64, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#endif
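// A hedged note on the selection above: DIM is assumed to arrive as a
// compile-time definition (e.g. `hipcc -DDIM=64 ...`) that picks the tuned
// instance bucket; a guard along these lines (not in the original source)
// would surface a missing definition early:
//
// #ifndef DIM
// #error "DIM (head dimension) must be defined at compile time, e.g. -DDIM=64"
// #endif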
@@ -260,7 +260,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 2; // method 1 will have slightly higher error; TODO: to investigate
    bool time_kernel     = true;

    // Overall QKV matrices shape
@@ -410,10 +410,6 @@ int run(int argc, char* argv[])
        // LSE
        std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
        std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
        problem_descs.push_back({
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
@@ -427,8 +423,6 @@ int run(int argc, char* argv[])
            y_gs_ms_os_strides,
            lse_gs_ms_lengths,
            lse_gs_ms_strides,
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -452,7 +446,7 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
        Tensor<DDataType> d_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
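        // A hedged host-side sketch (not part of this commit) of what the
        // YDotYGrad kernel produces: D = rowsum(Y * dY), which shares the
        // [G0, G1, M] lengths/strides of LSE and is why d_gs_ms reuses them
        // above. Assuming the (g0, g1, m, o) Tensor indexing used in these
        // examples:
        //
        // for(int g0 = 0; g0 < G0; ++g0)
        //     for(int g1 = 0; g1 < G1; ++g1)
        //         for(int m = 0; m < M; ++m)
        //         {
        //             float acc = 0.f;
        //             for(int o = 0; o < O; ++o)
        //                 acc += ck::type_convert<float>(y_gs_ms_os(g0, g1, m, o)) *
        //                        ck::type_convert<float>(ygrad_gs_ms_os(g0, g1, m, o));
        //             d_gs_ms(g0, g1, m) = ck::type_convert<DDataType>(acc);
        //         }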
        if(i < 4)
        {
            std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -44,13 +44,14 @@ __global__ void
        const InputDataType* __restrict__ p_ygrad_grid,
        DDataType* __restrict__ p_d_grid,
        const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            y_grid_desc_mblock_mperblock_nblock_nperblock,
        const DGridDescriptor_M d_grid_desc_m,
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -62,11 +63,10 @@ __global__ void
    const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
    GridwiseGemm::Run(p_y_grid + c_batch_offset,
                      p_ygrad_grid + c_batch_offset,
                      p_d_grid + d_batch_offset,
                      y_grid_desc_mblock_mperblock_nblock_nperblock,
                      d_grid_desc_m,
                      block_2_ctile_map);
@@ -74,7 +74,7 @@ __global__ void
    ignore = p_y_grid;
    ignore = p_ygrad_grid;
    ignore = p_d_grid;
    ignore = y_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = d_grid_desc_m;
    ignore = block_2_ctile_map;
    ignore = batch_count;
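    // The `ignore` idiom above boils down to a std::ignore-style sink; a
    // minimal sketch of the mechanism (CK's actual definition may differ):
    //
    // struct ignore_t
    // {
    //     template <typename T>
    //     constexpr const ignore_t& operator=(const T&) const { return *this; }
    // };
    // inline constexpr ignore_t ignore{};
    //
    // Assigning every parameter to it keeps this stub branch, compiled for
    // unsupported targets, free of unused-parameter warnings.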
@@ -98,18 +98,18 @@ template <typename GridwiseGemm,
          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
          typename B1GridDesc_BK0_N_BK1,
          typename LSEGridDescriptor_M,
          typename YGradGridDesc_O0_M_O1,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          typename C0MatrixMask,
          bool HasMainKBlockLoop,
          bool IsDropout,
          bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
        kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v1(
            const InputDataType* __restrict__ p_a_grid,
            const InputDataType* __restrict__ p_b_grid,
            ZDataType* __restrict__ p_z_grid,
@@ -128,10 +128,9 @@ __global__ void
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
                c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const LSEGridDescriptor_M lse_grid_desc_m,
            const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
@@ -144,7 +143,8 @@ __global__ void
            const index_t raw_m_padded,
            const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -174,70 +174,70 @@ __global__ void
    {
        for(index_t i = 0; i < nblock; i++)
        {
            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
                p_a_grid + a_batch_offset,
                p_b_grid + b_batch_offset,
                z_matrix_ptr,
                p_b1_grid + b1_batch_offset,
                p_lse_grid + lse_batch_offset,
                p_d_grid + lse_batch_offset,
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
                b_element_op,
                acc_element_op,
                b1_element_op,
                c_element_op,
                a_grid_desc_ak0_m_ak1,
                b_grid_desc_bk0_n_bk1,
                c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                b1_grid_desc_bk0_n_bk1,
                lse_grid_desc_m,
                ygrad_grid_desc_o0_m_o1,
                block_2_ctile_map,
                c0_matrix_mask,
                p_drop,
                ph,
                z_random_matrix_offset,
                raw_n_padded,
                i);
        }
    }
    else
    {
        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
            p_a_grid + a_batch_offset,
            p_b_grid + b_batch_offset,
            z_matrix_ptr,
            p_b1_grid + b1_batch_offset,
            p_lse_grid + lse_batch_offset,
            p_d_grid + lse_batch_offset,
            p_ygrad_grid + c_batch_offset,
            p_qgrad_grid + a_batch_offset,
            p_kgrad_grid + b_batch_offset,
            p_vgrad_grid + b1_batch_offset,
            p_shared,
            a_element_op,
            b_element_op,
            acc_element_op,
            b1_element_op,
            c_element_op,
            a_grid_desc_ak0_m_ak1,
            b_grid_desc_bk0_n_bk1,
            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
            b1_grid_desc_bk0_n_bk1,
            lse_grid_desc_m,
            ygrad_grid_desc_o0_m_o1,
            block_2_ctile_map,
            c0_matrix_mask,
            p_drop,
            ph,
            z_random_matrix_offset,
            raw_n_padded,
            0);
    }
#else
    ignore = p_a_grid;
@@ -328,7 +328,7 @@ template <index_t NumDimG,
          MaskingSpecialization MaskingSpec,
          bool Deterministic,
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
    : public BaseOperator // TODO: inherit attention bwd op once API stabilizes
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -336,12 +336,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();

    static constexpr index_t DMPerBlock = BlockSize;

    // TODO: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -613,10 +613,22 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
        const auto M    = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
        const auto MPad = M - MRaw;
        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                     GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M
            return transform_tensor_descriptor(d_grid_desc_mraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0>{}));
        }
        else
        {
            // not pad M
            return d_grid_desc_mraw;
        }
    }
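    // Worked example of the padding arithmetic above (assuming BlockSize ==
    // 256, as in the tuned instances): MRaw = 1000 gives
    //   M    = integer_divide_ceil(1000, 256) * 256 = 1024
    //   MPad = 1024 - 1000 = 24
    // so an M-padding GemmSpec right-pads the D descriptor by 24 rows, while
    // any other specialization returns the raw descriptor unchanged.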
    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
@@ -707,7 +719,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
    };

    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1<
        InputDataType, // TODO: distinguish A/B datatype
        OutputDataType,
        ZDataType,
@@ -729,7 +741,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
        B1GridDesc_BK0_N_BK1,
        YGridDesc_M_O,
        LSEGridDesc_M,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
@@ -772,14 +783,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
        Deterministic>;

    // GridwiseYDotYGrad
    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
                                                                                  DDataType,
                                                                                  DYGridDesc_M_O,
                                                                                  DGridDesc_M,
                                                                                  BlockSize,
                                                                                  DMPerBlock,
                                                                                  DKPerBlock,
                                                                                  Gemm1NPerBlock>;
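    // A hedged reading of this configuration: with DMPerBlock == BlockSize,
    // each workgroup reduces DMPerBlock rows of Y * dY, sweeping the O
    // dimension DKPerBlock columns at a time, conceptually
    //
    // for(o0 = 0; o0 < O; o0 += DKPerBlock)
    //     d[m] += sum over o in [o0, o0 + DKPerBlock) of y[m][o] * dy[m][o];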
    // Argument
    struct Argument : public BaseArgument
    {
@@ -808,7 +819,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
            const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
            const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
            const std::vector<index_t>& lse_gs_ms_lengths,
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
            const std::array<std::vector<ck::index_t>, NumAcc1Bias>
@@ -845,7 +855,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
              d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                     c_gs_ms_gemm1ns_strides)},
              lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
              d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
              k_grid_desc_n_k_{
                  Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              ygrad_grid_desc_o0_m_o1_{DeviceOp::MakeYGradGridDescriptor_O0_M_O1(
@@ -1024,19 +1034,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
        float ave_time = 0;
        {
            const index_t grid_size =
                arg.d_block_2_ctile_map_.CalculateGridSize(arg.d_y_grid_desc_m_o_) *
                arg.batch_count_;
            auto launch_kernel = [&]() {
                const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v1<
@@ -1073,9 +1072,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
                        : arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_)) *
                    arg.batch_count_;
            auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
                const auto kernel =
                    kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v1<
                        GridwiseGemm,
                        InputDataType,
                        OutputDataType,
@@ -1092,12 +1091,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
                        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
                        DeviceOp::B1GridDesc_BK0_N_BK1,
                        DeviceOp::LSEGridDesc_M,
                        DeviceOp::YGradGridDesc_O0_M_O1,
                        typename GridwiseGemm::DefaultBlock2CTileMap,
                        ComputeBasePtrOfStridedBatch,
                        C0MatrixMask,
                        has_main_k_block_loop_,
                        is_dropout_,
                        Deterministic>;
                return launch_and_time_kernel(
@@ -1126,7 +1125,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
                    arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
                    arg.b1_grid_desc_bk0_n_bk1_,
                    arg.lse_grid_desc_m_,
                    arg.ygrad_grid_desc_o0_m_o1_,
                    arg.block_2_ctile_map_,
                    arg.batch_count_,
@@ -1139,9 +1137,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
                    arg.m_raw_padded_,
                    arg.n_raw_padded_);
            };

            if(arg.p_drop_ > 0.0)
            {
                ave_time += launch_kernel(integral_constant<bool, false>{},
                                          integral_constant<bool, true>{});
            }
            else
            {
                ave_time += launch_kernel(integral_constant<bool, false>{},
                                          integral_constant<bool, false>{});
            }
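            // The branch above folds the runtime dropout flag into a
            // compile-time template parameter. A minimal sketch of the same
            // dispatch pattern in isolation (names here are illustrative, not
            // part of this file):
            //
            // template <typename Launcher>
            // float dispatch_dropout(bool use_dropout, Launcher launch)
            // {
            //     return use_dropout
            //                ? launch(integral_constant<bool, true>{})
            //                : launch(integral_constant<bool, false>{});
            // }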
            return ave_time;
        }
@@ -1161,17 +1166,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
    static bool IsSupportedArgument(const Argument& arg)
    {
#if DEBUG_LOG
        arg.Print();
#endif
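        // gfx908/gfx90a are the CDNA1/CDNA2 targets; gfx940/941/942 are
        // understood here to be the CDNA3 (MI300-series) targets.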
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
             ck::get_device_name() == "gfx942"))
        {
            return false;
        }
        // TODO: Check if tensor specialization & strides mismatch
        if(!GridwiseYDotYGrad::CheckValidity(arg.d_y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
        {
            return false;
        }
        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
@@ -1255,7 +1262,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
        const std::vector<index_t>& lse_gs_ms_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
@@ -1294,7 +1300,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
            c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
            c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
            lse_gs_ms_lengths,
            acc0_biases_gs_ms_ns_lengths,
            acc0_biases_gs_ms_ns_strides,
            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
@@ -1337,7 +1342,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
        const std::vector<index_t>& lse_gs_ms_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
@@ -1376,7 +1380,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
            c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
            c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
            lse_gs_ms_lengths,
            acc0_biases_gs_ms_ns_lengths,
            acc0_biases_gs_ms_ns_strides,
            acc1_biases_gs_ms_gemm1ns_lengths,
@@ -1402,7 +1405,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
        auto str = std::stringstream();
        // clang-format off
        str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -43,13 +43,14 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid,
DDataType* __restrict__ p_d_grid,
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
+y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_M d_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -61,11 +62,10 @@ __global__ void
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
-// GridwiseGemm::test();
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_d_grid + d_batch_offset,
-c_grid_desc_mblock_mperblock_nblock_nperblock,
+y_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m,
block_2_ctile_map);
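Note (illustrative sketch, not part of this commit): the batch offsets above are computed once per block and are uniform across each wavefront, so the kernels wrap them in __builtin_amdgcn_readfirstlane to keep the values in scalar registers. A minimal host-side C++ model of the same arithmetic, with readfirstlane modeled as the identity and batch_stride standing in for what GetLSEBasePtr(g_idx) would return:

#include <cstdint>

// On the GPU this broadcasts lane 0's value to a scalar register; here it is a no-op.
inline std::int32_t readfirstlane(std::int32_t v) { return v; }

std::int64_t batch_base_offset(std::int32_t block_id,
                               std::int32_t grid_size,
                               std::int32_t batch_count,
                               std::int64_t batch_stride)
{
    // Blocks are split evenly over batches; g_idx identifies this block's batch.
    const std::int32_t num_blocks_per_batch = readfirstlane(grid_size / batch_count);
    const std::int32_t g_idx = readfirstlane(block_id / num_blocks_per_batch);
    // Widen before multiplying, mirroring the long_index_t casts in the kernel.
    return static_cast<std::int64_t>(g_idx) * batch_stride;
}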
@@ -73,7 +73,7 @@ __global__ void
ignore = p_y_grid;
ignore = p_ygrad_grid;
ignore = p_d_grid;
-ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+ignore = y_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
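Note (illustrative sketch, not part of this commit): the #if guard extended in this commit compiles the kernel body only for the listed gfx targets; on any other target the body is empty and every parameter is assigned to the ignore sink so the compiler stays quiet about unused parameters. A self-contained HIP sketch of the same pattern, with ignore_sink and kernel_stub as illustrative names rather than CK identifiers:

#include <hip/hip_runtime.h>

namespace {
// Swallows any assignment; stands in for CK's `ignore`.
struct IgnoreT
{
    template <typename T>
    __device__ constexpr void operator=(T&&) const noexcept
    {
    }
};
__device__ constexpr IgnoreT ignore_sink{};
} // namespace

__global__ void kernel_stub(const float* __restrict__ p_in, float* __restrict__ p_out)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    // Real work happens only on supported targets.
    p_out[threadIdx.x] = p_in[threadIdx.x];
#else
    // Unsupported target: parameters are "assigned away", mirroring `ignore = ...;`.
    ignore_sink = p_in;
    ignore_sink = p_out;
#endif
}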
@@ -97,18 +97,18 @@ template <typename GridwiseGemm,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1,
typename LSEGridDescriptor_M,
-typename DGridDescriptor_M,
typename YGradGridDesc_M0_O_M1,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop,
+bool IsDropout,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
-kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v2(
+kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v2(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
@@ -130,7 +130,6 @@ __global__ void
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const LSEGridDescriptor_M lse_grid_desc_m,
-const DGridDescriptor_M d_grid_desc_m,
const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
@@ -143,7 +142,8 @@ __global__ void
const index_t raw_m_padded,
const index_t raw_n_padded)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -173,70 +173,70 @@ __global__ void
{
for(index_t i = 0; i < nblock; i++)
{
-GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-p_b_grid + b_batch_offset,
-z_matrix_ptr,
-p_b1_grid + b1_batch_offset,
-p_lse_grid + lse_batch_offset,
-p_d_grid + lse_batch_offset,
-p_ygrad_grid + c_batch_offset,
-p_qgrad_grid + a_batch_offset,
-p_kgrad_grid + b_batch_offset,
-p_vgrad_grid + b1_batch_offset,
-p_shared,
-a_element_op,
-b_element_op,
-acc_element_op,
-b1_element_op,
-c_element_op,
-a_grid_desc_ak0_m_ak1,
-b_grid_desc_bk0_n_bk1,
-c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-b1_grid_desc_bk0_n_bk1,
-lse_grid_desc_m,
-d_grid_desc_m,
-ygrad_grid_desc_m0_o_m1,
-block_2_ctile_map,
-c0_matrix_mask,
-p_drop,
-ph,
-z_random_matrix_offset,
-raw_n_padded,
-i);
+GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+p_a_grid + a_batch_offset,
+p_b_grid + b_batch_offset,
+z_matrix_ptr,
+p_b1_grid + b1_batch_offset,
+p_lse_grid + lse_batch_offset,
+p_d_grid + lse_batch_offset,
+p_ygrad_grid + c_batch_offset,
+p_qgrad_grid + a_batch_offset,
+p_kgrad_grid + b_batch_offset,
+p_vgrad_grid + b1_batch_offset,
+p_shared,
+a_element_op,
+b_element_op,
+acc_element_op,
+b1_element_op,
+c_element_op,
+a_grid_desc_ak0_m_ak1,
+b_grid_desc_bk0_n_bk1,
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+b1_grid_desc_bk0_n_bk1,
+lse_grid_desc_m,
+ygrad_grid_desc_m0_o_m1,
+block_2_ctile_map,
+c0_matrix_mask,
+p_drop,
+ph,
+z_random_matrix_offset,
+raw_n_padded,
+i);
}
}
else
{
-GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-p_b_grid + b_batch_offset,
-z_matrix_ptr,
-p_b1_grid + b1_batch_offset,
-p_lse_grid + lse_batch_offset,
-p_d_grid + lse_batch_offset,
-p_ygrad_grid + c_batch_offset,
-p_qgrad_grid + a_batch_offset,
-p_kgrad_grid + b_batch_offset,
-p_vgrad_grid + b1_batch_offset,
-p_shared,
-a_element_op,
-b_element_op,
-acc_element_op,
-b1_element_op,
-c_element_op,
-a_grid_desc_ak0_m_ak1,
-b_grid_desc_bk0_n_bk1,
-c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-b1_grid_desc_bk0_n_bk1,
-lse_grid_desc_m,
-d_grid_desc_m,
-ygrad_grid_desc_m0_o_m1,
-block_2_ctile_map,
-c0_matrix_mask,
-p_drop,
-ph,
-z_random_matrix_offset,
-raw_n_padded,
-0);
+GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+p_a_grid + a_batch_offset,
+p_b_grid + b_batch_offset,
+z_matrix_ptr,
+p_b1_grid + b1_batch_offset,
+p_lse_grid + lse_batch_offset,
+p_d_grid + lse_batch_offset,
+p_ygrad_grid + c_batch_offset,
+p_qgrad_grid + a_batch_offset,
+p_kgrad_grid + b_batch_offset,
+p_vgrad_grid + b1_batch_offset,
+p_shared,
+a_element_op,
+b_element_op,
+acc_element_op,
+b1_element_op,
+c_element_op,
+a_grid_desc_ak0_m_ak1,
+b_grid_desc_bk0_n_bk1,
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+b1_grid_desc_bk0_n_bk1,
+lse_grid_desc_m,
+ygrad_grid_desc_m0_o_m1,
+block_2_ctile_map,
+c0_matrix_mask,
+p_drop,
+ph,
+z_random_matrix_offset,
+raw_n_padded,
+0);
}
#else
ignore = p_a_grid;
@@ -334,7 +334,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
+struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -347,7 +347,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2;
+using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -616,13 +616,25 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
-const auto M = math::integer_divide_ceil(MRaw, BlockSize) * BlockSize;
+const auto M = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
const auto MPad = M - MRaw;
-return transform_tensor_descriptor(d_grid_desc_mraw,
-make_tuple(make_right_pad_transform(MRaw, MPad)),
-make_tuple(Sequence<0>{}),
-make_tuple(Sequence<0>{}));
+if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
+{
+// pad M
+return transform_tensor_descriptor(d_grid_desc_mraw,
+make_tuple(make_right_pad_transform(MRaw, MPad)),
+make_tuple(Sequence<0>{}),
+make_tuple(Sequence<0>{}));
+}
+else
+{
+// not pad M
+return d_grid_desc_mraw;
+}
}
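Note (illustrative sketch, not part of this commit): MakeDGridDescriptor_M now pads M up to a multiple of DMPerBlock only when the GemmSpecialization actually requests M padding, instead of padding unconditionally as before. The compile-time branch can be modeled in isolation like this, with a simplified enum and plain integer lengths standing in for the CK tensor descriptors:

#include <cstdint>

enum class GemmSpecialization { Default, MPadding, MNPadding, MKPadding, MNKPadding };

template <GemmSpecialization Spec, std::int32_t MPerBlock>
constexpr std::int32_t padded_m_length(std::int32_t m_raw)
{
    if constexpr(Spec == GemmSpecialization::MPadding ||
                 Spec == GemmSpecialization::MNPadding ||
                 Spec == GemmSpecialization::MKPadding ||
                 Spec == GemmSpecialization::MNKPadding)
    {
        // pad M: round up to the next tile boundary
        return (m_raw + MPerBlock - 1) / MPerBlock * MPerBlock;
    }
    else
    {
        // not pad M: length used as-is; caller must guarantee divisibility
        return m_raw;
    }
}

static_assert(padded_m_length<GemmSpecialization::MPadding, 64>(100) == 128);
static_assert(padded_m_length<GemmSpecialization::Default, 64>(128) == 128);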
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
@@ -713,7 +725,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
};
// GridwiseGemm
-using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2<
+using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
@@ -735,7 +747,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
-LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -786,14 +797,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
Deterministic>;
// GridwiseYDotYGrad
-using GridwiseYDotYGrad =
-GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
-DDataType, // datatype
-YGridDesc_M_O,
-DGridDesc_M,
-BlockSize,
-DMPerBlock,
-DKPerBlock>;
+using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
+DDataType,
+DYGridDesc_M_O,
+DGridDesc_M,
+BlockSize,
+DMPerBlock,
+DKPerBlock,
+Gemm1NPerBlock>;
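Note (illustrative sketch, not part of this commit): GridwiseYDotYGrad produces the per-row reduction D[m] = sum_o Y[m][o] * dY[m][o] that the softmax backward pass consumes, which is consistent with the D descriptor now being sized from the LSE lengths instead of a separate d_gs_ms_lengths argument. A plain C++ reference of that reduction, assuming this row-dot definition (the grid kernel tiles the same computation over DMPerBlock-by-DKPerBlock blocks rather than looping serially):

#include <vector>

std::vector<float> y_dot_ygrad(const std::vector<float>& y,
                               const std::vector<float>& ygrad,
                               int M,
                               int O)
{
    std::vector<float> d(M, 0.f);
    for(int m = 0; m < M; ++m)
    {
        float acc = 0.f;
        for(int o = 0; o < O; ++o)
        {
            // elementwise product of Y and dY, reduced along the O dimension
            acc += y[m * O + o] * ygrad[m * O + o];
        }
        d[m] = acc; // one scalar per row, same shape as the LSE vector
    }
    return d;
}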
// Argument
struct Argument : public BaseArgument
{
@@ -822,7 +833,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
-const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
@@ -859,7 +869,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
-d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
+d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
@@ -1037,19 +1047,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
float ave_time = 0;
{
const index_t grid_size =
-(Deterministic
-? 1
-: arg.d_block_2_ctile_map_.CalculateGridSize(arg.d_y_grid_desc_m_o_)) *
-arg.batch_count_;
-std::cout << "grid_size: " << grid_size
-<< "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_
-<< " arg.batch_count_: " << arg.batch_count_ << std::endl;
-std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock
-<< std::endl;
-std::cout << "arg.d_y_grid_desc_m_o_: {" << arg.d_y_grid_desc_m_o_.GetLength(I0)
-<< "," << arg.d_y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
-std::cout << "arg.d_grid_desc_m_: {" << arg.d_grid_desc_m_.GetLength(I0) << "}"
-<< std::endl;
+arg.d_block_2_ctile_map_.CalculateGridSize(arg.d_y_grid_desc_m_o_) *
+arg.batch_count_;
auto launch_kernel = [&]() {
const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v2<
@@ -1090,9 +1089,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
-auto launch_kernel = [&](auto has_main_k_block_loop_) {
+auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
-kernel_batched_multihead_attention_backward_xdl_cshuffle_light_v2<
+kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v2<
GridwiseGemm,
InputDataType,
OutputDataType,
@@ -1109,12 +1108,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M,
-DeviceOp::LSEGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
+is_dropout_,
Deterministic>;
return launch_and_time_kernel(
@@ -1143,7 +1142,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_,
-arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_,
arg.batch_count_,
@@ -1162,13 +1160,22 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
-ave_time += launch_kernel(integral_constant<bool, true>{});
+if(arg.p_drop_ > 0.0)
+ave_time += launch_kernel(integral_constant<bool, true>{},
+integral_constant<bool, true>{});
+else
+ave_time += launch_kernel(integral_constant<bool, true>{},
+integral_constant<bool, false>{});
}
else
{
-ave_time += launch_kernel(integral_constant<bool, false>{});
+if(arg.p_drop_ > 0.0)
+ave_time += launch_kernel(integral_constant<bool, false>{},
+integral_constant<bool, true>{});
+else
+ave_time += launch_kernel(integral_constant<bool, false>{},
+integral_constant<bool, false>{});
}
return ave_time;
}
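Note (illustrative sketch, not part of this commit): the launch path above turns two runtime facts, whether the K loop has a main iteration and whether dropout is active (arg.p_drop_ > 0.0), into compile-time template parameters by invoking a generic lambda with std::integral_constant, so each combination instantiates a separate kernel with the untaken branches compiled out. A minimal standalone version of the idiom, with run_case standing in for launch_kernel:

#include <cstdio>
#include <type_traits>

template <bool HasMainLoop, bool IsDropout>
void run_case()
{
    std::printf("main_loop=%d dropout=%d\n", HasMainLoop, IsDropout);
}

void dispatch(bool has_main_loop, double p_drop)
{
    // The generic lambda receives integral_constant values, so the bools are
    // recovered as compile-time constants inside the instantiation.
    auto launch = [](auto has_main_loop_, auto is_dropout_) {
        run_case<decltype(has_main_loop_)::value, decltype(is_dropout_)::value>();
    };

    if(has_main_loop)
    {
        if(p_drop > 0.0)
            launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, true>{});
        else
            launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, false>{});
    }
    else
    {
        if(p_drop > 0.0)
            launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, true>{});
        else
            launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, false>{});
    }
}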
@@ -1188,17 +1195,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static bool IsSupportedArgument(const Argument& arg)
{
+#if DEBUG_LOG
+arg.Print();
+#endif
-if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+ck::get_device_name() == "gfx942"))
{
return false;
}
// TODO: Check if tensor specialization & strides mismatch
+if(!GridwiseYDotYGrad::CheckValidity(arg.d_y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
+{
+return false;
+}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
@@ -1286,7 +1295,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
-const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
@@ -1325,7 +1333,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
-d_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
@@ -1368,7 +1375,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
-const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
@@ -1407,7 +1413,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
-d_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
@@ -1433,7 +1438,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
auto str = std::stringstream();
// clang-format off
-str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2"
+str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
@@ -73,7 +73,7 @@ __global__ void
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
-c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -140,7 +140,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
-c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
@@ -175,7 +175,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
-c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
......
@@ -1040,7 +1040,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
return ave_time;
}
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -36,7 +36,8 @@ __global__ void
kernel_grouped_multihead_attention_backward_ydotygrad_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
+bool IsDropout,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
-kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
@@ -106,7 +108,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
@@ -158,7 +161,7 @@ __global__ void
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
-GridwiseGemm::template Run<HasMainKBlockLoop>(
+GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
@@ -180,7 +183,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_,
-arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
@@ -194,7 +196,7 @@ __global__ void
}
else
{
-GridwiseGemm::template Run<HasMainKBlockLoop>(
+GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
@@ -216,7 +218,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_,
-arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
@@ -307,7 +308,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -320,7 +321,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1;
+using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
@@ -341,9 +342,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides;
-std::vector<index_t> d_gs_ms_lengths;
-std::vector<index_t> d_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const auto M = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
const auto MPad = M - MRaw;
-return transform_tensor_descriptor(d_grid_desc_mraw,
-make_tuple(make_right_pad_transform(MRaw, MPad)),
-make_tuple(Sequence<0>{}),
-make_tuple(Sequence<0>{}));
+if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
+{
+// pad M
+return transform_tensor_descriptor(d_grid_desc_mraw,
+make_tuple(make_right_pad_transform(MRaw, MPad)),
+make_tuple(Sequence<0>{}),
+make_tuple(Sequence<0>{}));
+}
+else
+{
+// not pad M
+return d_grid_desc_mraw;
+}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
};
// GridwiseGemm
-using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1<
+using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
-LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -725,14 +734,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
// GridwiseYDotYGrad
-using GridwiseYDotYGrad =
-GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
-DDataType, // datatype
-YGridDesc_M_O,
-DGridDesc_M,
-BlockSize,
-DMPerBlock,
-DKPerBlock>;
+using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
+DDataType,
+DYGridDesc_M_O,
+DGridDesc_M,
+BlockSize,
+DMPerBlock,
+DKPerBlock,
+Gemm1NPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
@@ -776,7 +785,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter
DDataType* p_d_grid_;
-DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -956,7 +964,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m =
-DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
+DeviceOp::MakeDGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
@@ -1001,7 +1009,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
raw_m_padded,
raw_n_padded,
p_d_grid,
-d_y_grid_desc_m_o,
d_grid_desc_m,
d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -1105,17 +1112,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ave_time = launch_kernel();
}
-auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
-GridwiseGemm,
-GroupKernelArg,
-AElementwiseOperation,
-BElementwiseOperation,
-AccElementwiseOperation,
-B1ElementwiseOperation,
-CElementwiseOperation,
-has_main_k_block_loop_,
-Deterministic>;
+auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
+const auto kernel =
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1<
+GridwiseGemm,
+GroupKernelArg,
+AElementwiseOperation,
+BElementwiseOperation,
+AccElementwiseOperation,
+B1ElementwiseOperation,
+CElementwiseOperation,
+has_main_k_block_loop_,
+is_dropout_,
+Deterministic>;
return launch_and_time_kernel(
stream_config,
@@ -1139,11 +1148,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
-ave_time += launch_kernel(integral_constant<bool, true>{});
+if(arg.p_dropout_ > 0.0)
+ave_time += launch_kernel(integral_constant<bool, true>{},
+integral_constant<bool, true>{});
+else
+ave_time += launch_kernel(integral_constant<bool, true>{},
+integral_constant<bool, false>{});
}
else if(!some_has_main_k_block_loop)
{
-ave_time += launch_kernel(integral_constant<bool, false>{});
+if(arg.p_dropout_ > 0.0)
+ave_time += launch_kernel(integral_constant<bool, false>{},
+integral_constant<bool, true>{});
+else
+ave_time += launch_kernel(integral_constant<bool, false>{},
+integral_constant<bool, false>{});
}
else
{
@@ -1169,22 +1188,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static bool IsSupportedArgument(const Argument& arg)
{
-if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+ck::get_device_name() == "gfx942"))
{
return false;
}
for(index_t i = 0; i < arg.group_count_; i++)
{
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i];
-if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
-kernel_arg.d_block_2_ctile_map_))
-{
-return false;
-}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
@@ -1352,7 +1367,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
auto str = std::stringstream();
// clang-format off
-str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1"
+str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -36,7 +36,8 @@ __global__ void
kernel_grouped_multihead_attention_backward_ydotygrad_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const index_t group_count)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
+bool IsDropout,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
-kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
@@ -106,7 +108,8 @@ __global__ void
const unsigned long long seed,
const unsigned long long offset)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
@@ -158,7 +161,7 @@ __global__ void
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
{
-GridwiseGemm::template Run<HasMainKBlockLoop>(
+GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
@@ -180,7 +183,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_,
-arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
@@ -194,7 +196,7 @@ __global__ void
}
else
{
-GridwiseGemm::template Run<HasMainKBlockLoop>(
+GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
z_matrix_ptr,
@@ -216,7 +218,6 @@ __global__ void
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_,
-arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
@@ -314,7 +315,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -327,7 +328,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2;
+using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
@@ -348,9 +349,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides;
-std::vector<index_t> d_gs_ms_lengths;
-std::vector<index_t> d_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const auto M = math::integer_divide_ceil(MRaw, DMPerBlock) * DMPerBlock;
const auto MPad = M - MRaw;
-return transform_tensor_descriptor(d_grid_desc_mraw,
-make_tuple(make_right_pad_transform(MRaw, MPad)),
-make_tuple(Sequence<0>{}),
-make_tuple(Sequence<0>{}));
+if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
+{
+// pad M
+return transform_tensor_descriptor(d_grid_desc_mraw,
+make_tuple(make_right_pad_transform(MRaw, MPad)),
+make_tuple(Sequence<0>{}),
+make_tuple(Sequence<0>{}));
+}
+else
+{
+// not pad M
+return d_grid_desc_mraw;
+}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
};
// GridwiseGemm
-using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2<
+using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
-LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -733,14 +742,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
// GridwiseYDotYGrad
-using GridwiseYDotYGrad =
-GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
-DDataType, // datatype
-YGridDesc_M_O,
-DGridDesc_M,
-BlockSize,
-DMPerBlock,
-DKPerBlock>;
+using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
+DDataType,
+DYGridDesc_M_O,
+DGridDesc_M,
+BlockSize,
+DMPerBlock,
+DKPerBlock,
+Gemm1NPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
@@ -784,7 +793,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter
DDataType* p_d_grid_;
-DYGridDesc_M_O d_y_grid_desc_m_o_;
DGridDesc_M d_grid_desc_m_;
DBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -870,6 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
index_t z_random_matrix_offset = 0;
d_grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
@@ -920,6 +929,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
+c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
+GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
+z_grid_desc_m_n);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
(Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(k_grid_desc_n_k)) *
@@ -959,7 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameters
const auto p_d_grid = static_cast<DDataType*>(p_Ds[i]);
const auto d_grid_desc_m =
-DeviceOp::MakeDGridDescriptor_M(problem_desc.d_gs_ms_lengths[NumDimG]);
+DeviceOp::MakeDGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
const auto d_y_grid_desc_m_o = DTransform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
@@ -1004,7 +1017,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
raw_m_padded,
raw_n_padded,
p_d_grid,
-d_y_grid_desc_m_o,
d_grid_desc_m,
d_block_2_ctile_map,
d_y_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -1107,17 +1119,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ave_time = launch_kernel();
}
-auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
-GridwiseGemm,
-GroupKernelArg,
-AElementwiseOperation,
-BElementwiseOperation,
-AccElementwiseOperation,
-B1ElementwiseOperation,
-CElementwiseOperation,
-has_main_k_block_loop_,
-Deterministic>;
+auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
+const auto kernel =
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2<
+GridwiseGemm,
+GroupKernelArg,
+AElementwiseOperation,
+BElementwiseOperation,
+AccElementwiseOperation,
+B1ElementwiseOperation,
+CElementwiseOperation,
+has_main_k_block_loop_,
+is_dropout_,
+Deterministic>;
return launch_and_time_kernel(
stream_config,
...@@ -1141,11 +1155,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1141,11 +1155,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time += launch_kernel(integral_constant<bool, true>{}); if(arg.p_dropout_ > 0.0)
ave_time += launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
else
ave_time += launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time += launch_kernel(integral_constant<bool, false>{}); if(arg.p_dropout_ > 0.0)
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
else
ave_time += launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
else else
{ {
...@@ -1171,7 +1195,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1171,7 +1195,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
...@@ -1181,11 +1207,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1181,11 +1207,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i]; const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i]; const auto& device_arg = arg.group_device_args_[i];
if(!GridwiseYDotYGrad::CheckValidity(kernel_arg.d_y_grid_desc_m_o_,
kernel_arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
...@@ -1358,7 +1379,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2 ...@@ -1358,7 +1379,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2" str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
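Note on the dispatch change above: the new `is_dropout_` flag follows the same pattern as `has_main_k_block_loop_` — a runtime condition (`arg.p_dropout_ > 0.0`) is folded into a compile-time template parameter, so the no-dropout instantiation carries none of the RNG or Z-store code. A minimal standalone sketch of the pattern follows; the stub kernel and numeric values are illustrative, not CK's API.

#include <iostream>
#include <type_traits>

// Hypothetical stand-in for the real gridwise kernel: because IsDropout is a
// template parameter, the "false" instantiation compiles the dropout path out.
template <bool HasMainKBlockLoop, bool IsDropout>
void kernel_stub()
{
    if constexpr(IsDropout)
        std::cout << "dropout path: RNG + z-matrix store compiled in\n";
    else
        std::cout << "no-dropout path: dropout code compiled out\n";
}

int main()
{
    const float p_dropout = 0.1f; // runtime value, standing in for arg.p_dropout_

    // Fold the runtime flag into a compile-time bool, mirroring the
    // launch_kernel(integral_constant<bool, ...>{}, ...) dispatch above.
    auto launch = [](auto has_main_loop, auto is_dropout) {
        kernel_stub<decltype(has_main_loop)::value, decltype(is_dropout)::value>();
    };

    if(p_dropout > 0.0f)
        launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, false>{});
}

The cost of this pattern is one extra kernel instantiation per flag combination, which is why the dispatch above enumerates all the integral_constant pairs explicitly.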
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -41,7 +41,6 @@ template <typename InputDataType,
           typename VGridDesc_O0_N_O1,
           typename YGridDesc_M_O,
           typename LSEGridDesc_M,
-          typename DGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -83,7 +82,7 @@ template <typename InputDataType,
           bool MaskOutUpperTriangle,
           bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
+struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
@@ -1155,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
     }
     template <bool HasMainKBlockLoop,
+              bool IsDropout,
               typename Block2CTileMap,
               typename C0MatrixMask,
               typename YGradGridDesc_O0_M_O1>
@@ -1180,7 +1180,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
             z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
         const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
         const LSEGridDesc_M& lse_grid_desc_m,
-        const DGridDesc_M& d_grid_desc_m,
         const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
@@ -1206,7 +1205,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
         const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
         const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_d_grid, d_grid_desc_m.GetElementSpaceSize());
+            p_d_grid, lse_grid_desc_m.GetElementSpaceSize()); // reuse lse grid descriptor
         const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize());
         auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1532,6 +1531,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
                                              acc0_thread_origin[I5],
                                              acc0_thread_origin[I6])};
+        auto d_thread_copy_global_to_vgpr =
+            ThreadwiseTensorSliceTransfer_v2<FloatD,
+                                             FloatGemmAcc,
+                                             decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
+                                             decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
+                                             Sequence<1, m0, m1, m2, m3, m4>,
+                                             Sequence<0, 1, 2, 3, 4, 5>,
+                                             5,
+                                             1,
+                                             1,
+                                             true /* ResetCoordAfterRun */>{
+                lse_grid_desc_mb_m0_m1_m2_m3_m4,
+                make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
+                                 acc0_thread_origin[I0],           // mrepeat
+                                 acc0_thread_origin[I2],           // mwave
+                                 acc0_thread_origin[I4],           // mperxdl
+                                 acc0_thread_origin[I5],
+                                 acc0_thread_origin[I6])};
         //
         // z vgpr copy to global
         //
@@ -1651,11 +1669,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
             // load d and lse
             //
-            lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
-                                               d_grid_buf,
-                                               lse_thread_desc_mb_m0_m1_m2_m3_m4,
-                                               make_tuple(I0, I0, I0, I0, I0, I0),
-                                               y_dot_ygrad_thread_buf);
+            d_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
+                                             d_grid_buf,
+                                             lse_thread_desc_mb_m0_m1_m2_m3_m4,
+                                             make_tuple(I0, I0, I0, I0, I0, I0),
+                                             y_dot_ygrad_thread_buf);
             lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
                                                lse_grid_buf,
@@ -1743,56 +1761,64 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
             constexpr auto position_offset = M3 * M4;
             // save z to global
-            if(p_z_grid)
-            {
-                auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-                auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-                auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
-                auto m_global = m_local + m_block_data_idx_on_grid;
-                auto n_global = n_local + n_block_data_idx_on_grid;
-
-                auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
-                                          n_global; // unique element global 1d id
-
-                auto global_elem_id =
-                    (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
-
-                blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
-                                                                    decltype(z_tenor_buffer),
-                                                                    decltype(position_offset),
-                                                                    true>(
-                    s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
-
-                z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                                 z_tenor_buffer,
-                                                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                                                 z_grid_buf);
-            }
-            else
-            {
-                ignore = z_grid_buf;
-
-                auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-                auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-                auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
-                auto m_global = m_local + m_block_data_idx_on_grid;
-                auto n_global = n_local + n_block_data_idx_on_grid;
-
-                auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
-                                          n_global; // unique element global 1d id
-                auto global_elem_id =
-                    (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
-                // P_dropped
-                blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
-                                                               decltype(position_offset),
-                                                               true>(
-                    s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
-            }
+            if constexpr(IsDropout)
+            {
+                if(p_z_grid)
+                {
+                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+                    auto m_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                    auto n_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                    auto m_global = m_local + m_block_data_idx_on_grid;
+                    auto n_global = n_local + n_block_data_idx_on_grid;
+
+                    auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
+                                              n_global; // unique element global 1d id
+
+                    auto global_elem_id =
+                        (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
+
+                    blockwise_dropout
+                        .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
+                                                           decltype(z_tenor_buffer),
+                                                           decltype(position_offset),
+                                                           true>(
+                            s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
+
+                    z_thread_copy_vgpr_to_global.Run(
+                        z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                        z_tenor_buffer,
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                        z_grid_buf);
+                }
+                else
+                {
+                    ignore = z_grid_buf;
+
+                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+                    auto m_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                    auto n_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                    auto m_global = m_local + m_block_data_idx_on_grid;
+                    auto n_global = n_local + n_block_data_idx_on_grid;
+
+                    auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
+                                              n_global; // unique element global 1d id
+                    auto global_elem_id =
+                        (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
+                    // P_dropped
+                    blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
+                                                                   decltype(position_offset),
+                                                                   true>(
+                        s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
+                }
+            }
             block_sync_lds(); // wait for gemm1 LDS read
             // dS = P * (dP - Y_dot_dY)
@@ -1965,6 +1991,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
                 make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
             lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
                                                               make_multi_index(-1, 0, 0, 0, 0, 0));
+            d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
+                                                            make_multi_index(-1, 0, 0, 0, 0, 0));
         } while(0 < gemm0_m_block_outer_index--); // end j loop
         // shuffle dK&dV and write
......
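Note on the descriptor reuse above: LSE and D = rowsum(dY * Y) are both length-M vectors, and a grid descriptor encodes only shape and strides, never a data pointer, so `lse_grid_desc_m` can address the D buffer as well; that is what lets the `DGridDesc_M` template and kernel parameters be dropped. A rough host-side sketch of the idea, where `GridDesc1D` is a simplified stand-in rather than the library type:

#include <cstddef>
#include <vector>

// A descriptor only describes layout; two same-shaped buffers can share one.
struct GridDesc1D
{
    std::size_t length;
    std::size_t GetElementSpaceSize() const { return length; }
};

int main()
{
    const std::size_t M = 128; // padded sequence length (illustrative)
    const GridDesc1D lse_grid_desc_m{M};

    // Both the LSE vector and the D vector are sized from the same descriptor.
    std::vector<float> lse_buf(lse_grid_desc_m.GetElementSpaceSize());
    std::vector<float> d_buf(lse_grid_desc_m.GetElementSpaceSize()); // reuse lse grid descriptor
    return 0;
}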
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -41,7 +41,6 @@ template <typename InputDataType,
           typename VGridDesc_N0_O_N1,
           typename YGridDesc_M_O,
           typename LSEGridDesc_M,
-          typename DGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -91,7 +90,7 @@ template <typename InputDataType,
           bool MaskOutUpperTriangle,
           bool Deterministic,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
+struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
@@ -1110,6 +1109,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
     }
     template <bool HasMainKBlockLoop,
+              bool IsDropout,
               typename Block2CTileMap,
               typename C0MatrixMask,
               typename YGradGridDesc_M0_O_M1>
@@ -1135,7 +1135,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
             z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
         const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
         const LSEGridDesc_M& lse_grid_desc_m,
-        const DGridDesc_M& d_grid_desc_m,
         const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
@@ -1161,7 +1160,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
         const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
         const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_d_grid, d_grid_desc_m.GetElementSpaceSize());
+            p_d_grid, lse_grid_desc_m.GetElementSpaceSize()); // reuse lse grid descriptor
         const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
         auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1516,6 +1515,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
                                              acc0_thread_origin[I5],
                                              acc0_thread_origin[I6])};
+        auto d_thread_copy_global_to_vgpr =
+            ThreadwiseTensorSliceTransfer_v2<FloatD,
+                                             FloatGemmAcc,
+                                             decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
+                                             decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
+                                             Sequence<1, m0, m1, m2, m3, m4>,
+                                             Sequence<0, 1, 2, 3, 4, 5>,
+                                             5,
+                                             1,
+                                             1,
+                                             true /* ResetCoordAfterRun */>{
+                lse_grid_desc_mb_m0_m1_m2_m3_m4,
+                make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
+                                 acc0_thread_origin[I0],           // mrepeat
+                                 acc0_thread_origin[I2],           // mwave
+                                 acc0_thread_origin[I4],           // mperxdl
+                                 acc0_thread_origin[I5],
+                                 acc0_thread_origin[I6])};
         //
         // z vgpr copy to global
         //
@@ -1612,11 +1630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
             // load d and lse
             //
-            lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
-                                               d_grid_buf,
-                                               lse_thread_desc_mb_m0_m1_m2_m3_m4,
-                                               make_tuple(I0, I0, I0, I0, I0, I0),
-                                               y_dot_ygrad_thread_buf);
+            d_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
+                                             d_grid_buf,
+                                             lse_thread_desc_mb_m0_m1_m2_m3_m4,
+                                             make_tuple(I0, I0, I0, I0, I0, I0),
+                                             y_dot_ygrad_thread_buf);
             lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
                                                lse_grid_buf,
@@ -1706,55 +1724,63 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
             constexpr auto position_offset = M3 * M4;
             // save z to global
-            if(p_z_grid)
-            {
-                auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-                auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-                auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
-                auto m_global = m_local + m_block_data_idx_on_grid;
-                auto n_global = n_local + n_block_data_idx_on_grid;
-
-                auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
-                                          n_global; // unique element global 1d id
-
-                auto global_elem_id =
-                    (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
-
-                blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
-                                                                    decltype(z_tenor_buffer),
-                                                                    decltype(position_offset),
-                                                                    true>(
-                    s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
-
-                z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                                 z_tenor_buffer,
-                                                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                                                 z_grid_buf);
-            }
-            else
-            {
-                ignore = z_grid_buf;
-                auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-                auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-                auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
-                auto m_global = m_local + m_block_data_idx_on_grid;
-                auto n_global = n_local + n_block_data_idx_on_grid;
-
-                auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
-                                          n_global; // unique element global 1d id
-                auto global_elem_id =
-                    (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
-                // P_dropped
-                blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
-                                                               decltype(position_offset),
-                                                               true>(
-                    s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
-            }
+            if constexpr(IsDropout)
+            {
+                if(p_z_grid)
+                {
+                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+                    auto m_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                    auto n_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                    auto m_global = m_local + m_block_data_idx_on_grid;
+                    auto n_global = n_local + n_block_data_idx_on_grid;
+
+                    auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
+                                              n_global; // unique element global 1d id
+
+                    auto global_elem_id =
+                        (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
+
+                    blockwise_dropout
+                        .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
+                                                           decltype(z_tenor_buffer),
+                                                           decltype(position_offset),
+                                                           true>(
+                            s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
+
+                    z_thread_copy_vgpr_to_global.Run(
+                        z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                        z_tenor_buffer,
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                        z_grid_buf);
+                }
+                else
+                {
+                    ignore = z_grid_buf;
+
+                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+                    auto m_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                    auto n_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                    auto m_global = m_local + m_block_data_idx_on_grid;
+                    auto n_global = n_local + n_block_data_idx_on_grid;
+
+                    auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
+                                              n_global; // unique element global 1d id
+                    auto global_elem_id =
+                        (global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
+                    // P_dropped
+                    blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
+                                                                   decltype(position_offset),
+                                                                   true>(
+                        s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
+                }
+            }
             block_sync_lds(); // wait for gemm1 LDS read
             // gemm dV
@@ -2005,6 +2031,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
                 qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
             lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
                                                               make_multi_index(-1, 0, 0, 0, 0, 0));
+            d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
+                                                            make_multi_index(-1, 0, 0, 0, 0, 0));
             z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                 make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
......
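Note on the slice-window updates above: the q-loop kernels walk the outer j loop over M blocks from last to first, so every per-row loader steps its source window by -1 in the MBlock dimension each iteration; the newly split `d_thread_copy_global_to_vgpr` must move in lockstep with the LSE copier, or D and LSE would be read from different rows. A toy model of that invariant, where names mirror the kernel's but the types are simplified stand-ins:

#include <cstdio>

// Minimal stand-in for a threadwise copier: only the window origin matters here.
struct ThreadCopy
{
    int mblock; // current window origin in the MBlock dimension
    void MoveSrcSliceWindow(int step) { mblock += step; }
};

int main()
{
    const int num_mblocks = 4;
    ThreadCopy lse_copy{num_mblocks - 1}; // both copiers start at the last m-block,
    ThreadCopy d_copy{num_mblocks - 1};   // matching num_gemm0_m_block_outer_loop - 1

    for(int j = num_mblocks - 1; j >= 0; --j)
    {
        std::printf("iter %d: lse@%d d@%d\n", j, lse_copy.mblock, d_copy.mblock);
        lse_copy.MoveSrcSliceWindow(-1); // mirrors make_multi_index(-1, 0, 0, 0, 0, 0)
        d_copy.MoveSrcSliceWindow(-1);   // the added copier steps back in lockstep
    }
}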
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -26,7 +26,8 @@ template <typename InputDataType,
           typename DGridDesc_M,
           index_t BlockSize,
           index_t MPerBlock,
-          index_t NPerBlock>
+          index_t NPerBlock,
+          index_t NPadded>
 struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
 {
     static constexpr auto I0 = Number<0>{};
@@ -103,7 +104,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
     MakeDefaultBlock2CTileMap(const YGridDesc_M_N& y_grid_desc_m_n)
     {
         // should rewrite BlockToCTileMap_M00_N0_M01Adapt
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, 1024, YGridDesc_M_N>(y_grid_desc_m_n);
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPadded, YGridDesc_M_N>(y_grid_desc_m_n);
     }
     using YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
......
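Note on the new `NPadded` parameter above: a block-to-C-tile map of this kind derives its launch grid by ceil-dividing the tensor extents by the per-block tile sizes, so the previously hardcoded N extent of 1024 could miscount N tiles whenever a block actually covers fewer than 1024 columns; the device op now threads `Gemm1NPerBlock` through instead. A rough illustration of the effect with made-up tile sizes (these are not the library's defaults, and `CalculateGridSize` is a simplified stand-in for the map's logic):

#include <cstdio>

// Grid size as a function of the per-block tile extents.
template <int MPerBlock, int NPerBlock>
int CalculateGridSize(int m, int n)
{
    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
    return ceil_div(m, MPerBlock) * ceil_div(n, NPerBlock);
}

int main()
{
    const int M = 256; // rows of Y (illustrative)
    const int O = 128; // columns of Y, i.e. the gemm1 N extent (illustrative)

    // With the fixed 1024 the whole row collapses into one N tile; with a
    // smaller per-block N coverage the map must launch more blocks.
    std::printf("NPerBlock = 1024 -> %d blocks\n", CalculateGridSize<64, 1024>(M, O));
    std::printf("NPerBlock = 32   -> %d blocks\n", CalculateGridSize<64, 32>(M, O));
}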