Commit 70d700b3 authored by danyao12's avatar danyao12
Browse files

optimized bwd split kernels w/ bias

parent 9e11dea6
...@@ -83,7 +83,7 @@ static constexpr ck::index_t NumDimO = 1; ...@@ -83,7 +83,7 @@ static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8 // When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
...@@ -119,7 +119,7 @@ using DeviceGemmInstance = ...@@ -119,7 +119,7 @@ using DeviceGemmInstance =
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | | // ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | // ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 64, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>; // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 64, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 32, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | | // ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | | // ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
......
...@@ -118,7 +118,7 @@ using DeviceGemmInstance = ...@@ -118,7 +118,7 @@ using DeviceGemmInstance =
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | | // ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | // ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 64, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>; // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 64, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 32, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | | // ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | | // ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
......
...@@ -82,6 +82,7 @@ __global__ void ...@@ -82,6 +82,7 @@ __global__ void
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename InputDataType, typename InputDataType,
typename D0DataType,
typename OutputDataType, typename OutputDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -93,6 +94,7 @@ template <typename GridwiseGemm, ...@@ -93,6 +94,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename D0GridDescriptor_M0_N0_M1_M2_N1_M3,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
...@@ -110,6 +112,7 @@ __global__ void ...@@ -110,6 +112,7 @@ __global__ void
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v1( kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v1(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
const D0DataType* __restrict__ p_d0_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_b1_grid, const InputDataType* __restrict__ p_b1_grid,
const LSEDataType* __restrict__ p_lse_grid, const LSEDataType* __restrict__ p_lse_grid,
...@@ -125,6 +128,7 @@ __global__ void ...@@ -125,6 +128,7 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
...@@ -168,6 +172,13 @@ __global__ void ...@@ -168,6 +172,13 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
...@@ -175,6 +186,7 @@ __global__ void ...@@ -175,6 +186,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
...@@ -191,6 +203,7 @@ __global__ void ...@@ -191,6 +203,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -209,6 +222,7 @@ __global__ void ...@@ -209,6 +222,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
...@@ -225,6 +239,7 @@ __global__ void ...@@ -225,6 +239,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -240,6 +255,7 @@ __global__ void ...@@ -240,6 +255,7 @@ __global__ void
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_d0_grid;
ignore = p_z_grid; ignore = p_z_grid;
ignore = p_b1_grid; ignore = p_b1_grid;
ignore = p_lse_grid; ignore = p_lse_grid;
...@@ -255,6 +271,7 @@ __global__ void ...@@ -255,6 +271,7 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
...@@ -307,6 +324,7 @@ template <index_t NumDimG, ...@@ -307,6 +324,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t B1K1, index_t B1K1,
...@@ -331,6 +349,7 @@ template <index_t NumDimG, ...@@ -331,6 +349,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -344,12 +363,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -344,12 +363,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
static constexpr index_t DMPerBlock = BlockSize; static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(std::is_void<D1DataType>::value, "Acc1 Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1;
...@@ -357,9 +376,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -357,9 +376,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec, static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>, Number<MPerBlock>,
...@@ -397,31 +416,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -397,31 +416,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// V in Gemm B1 position // V in Gemm B1 position
static auto static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths,
b1_gs_gemm1ns_gemm1ks_strides_vec), b1_gs_gemm1ns_gemm1ks_strides),
Number<B1K1>{}); Number<B1K1>{});
} }
...@@ -430,8 +449,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -430,8 +449,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -457,17 +476,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -457,17 +476,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -496,17 +515,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -496,17 +515,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// //
// YGrad in Gemm A position // YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec, static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides_vec) const std::vector<index_t>& y_gs_ms_os_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec), Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{}); Number<Y_O1>{});
} }
// V in Gemm B position // V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -532,17 +551,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -532,17 +551,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto v_grid_desc_nraw_oraw = const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw, const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
...@@ -554,10 +573,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -554,10 +573,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
} }
// Z in Gemm0 C position // Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
// //
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...@@ -568,10 +587,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -568,10 +587,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// //
// QGrad in Gemm C position // QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec, static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides_vec) const std::vector<index_t>& q_gs_ms_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
} }
// //
...@@ -579,10 +598,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -579,10 +598,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// //
// KGrad in Gemm C position // KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec, static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides_vec) const std::vector<index_t>& k_gs_ns_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -609,6 +628,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -609,6 +628,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return lse_grid_desc_mraw; return lse_grid_desc_mraw;
} }
} }
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeDGridDescriptor_M(index_t MRaw)
{ {
...@@ -637,6 +662,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -637,6 +662,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
...@@ -648,6 +674,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -648,6 +674,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {})); using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {})); using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -671,14 +698,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -671,14 +698,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch() {}
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
...@@ -696,6 +726,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -696,6 +726,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -719,6 +754,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -719,6 +754,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
...@@ -729,6 +765,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -729,6 +765,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -745,6 +782,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -745,6 +782,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
KGridDesc_N_K, KGridDesc_N_K,
D0GridDesc_M_N,
ZGridDesc_M_N, ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
...@@ -756,6 +794,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -756,6 +794,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
KPerBlock, KPerBlock,
Gemm1NPerBlock, Gemm1NPerBlock,
Gemm1KPerBlock, Gemm1KPerBlock,
Gemm2KPerBlock,
AK1, AK1,
BK1, BK1,
B1K1, B1K1,
...@@ -781,6 +820,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -781,6 +820,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
BBlockLdsExtraN, BBlockLdsExtraN,
D0BlockTransferSrcScalarPerVector,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -802,46 +842,46 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -802,46 +842,46 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const InputDataType* p_a_grid,
const InputDataType* p_a_grid, const InputDataType* p_b_grid,
const InputDataType* p_b_grid, ZDataType* p_z_grid,
ZDataType* p_z_grid, const InputDataType* p_b1_grid,
const InputDataType* p_b1_grid, const InputDataType* p_c_grid, // for dS
const InputDataType* p_c_grid, // for dS const LSEDataType* p_lse_grid,
const LSEDataType* p_lse_grid, DDataType* p_d_grid,
DDataType* p_d_grid, const InputDataType* p_ygrad_grid,
const InputDataType* p_ygrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_vgrad_grid,
OutputDataType* p_vgrad_grid, const D0DataType* p_acc0_bias,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D1DataType* p_acc1_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op, BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op, float p_drop,
float p_drop, std::tuple<unsigned long long, unsigned long long> seeds)
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_bias},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -902,22 +942,38 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -902,22 +942,38 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_bias;
ignore = p_acc1_biases; ignore = p_acc1_bias;
ignore = acc0_biases_gs_ms_ns_lengths; ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides; ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_desc_m_n = MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_strides[NumDimG + NumDimM]);
}
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -961,6 +1017,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -961,6 +1017,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// pointers // pointers
const InputDataType* p_a_grid_; const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_; const InputDataType* p_b_grid_;
const D0DataType* p_d0_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_; const InputDataType* p_c_grid_;
...@@ -974,6 +1031,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -974,6 +1031,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
...@@ -986,6 +1044,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -986,6 +1044,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// batch offsets // batch offsets
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
...@@ -1025,6 +1084,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1025,6 +1084,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
index_t m_raw_padded_; index_t m_raw_padded_;
index_t n_raw_padded_; index_t n_raw_padded_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Invoker // Invoker
...@@ -1085,6 +1147,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1085,6 +1147,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v1< kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v1<
GridwiseGemm, GridwiseGemm,
InputDataType, InputDataType,
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -1096,6 +1159,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1096,6 +1159,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
...@@ -1115,6 +1179,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1115,6 +1179,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_d0_grid_,
arg.p_z_grid_, arg.p_z_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_lse_grid_, arg.p_lse_grid_,
...@@ -1130,6 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1130,6 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
...@@ -1200,6 +1266,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1200,6 +1266,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false; return false;
} }
if constexpr(!is_same<D0DataType, void>::value)
{
if(arg.d0_n_length_stride_[1] == 1 &&
arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
...@@ -1245,44 +1324,44 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1245,44 +1324,44 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const InputDataType* p_a, MakeArgument(const InputDataType* p_a,
const InputDataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const InputDataType* p_b1, const InputDataType* p_b1,
const InputDataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
DDataType* p_d_grid, DDataType* p_d_grid,
const InputDataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -1295,8 +1374,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1295,8 +1374,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_vgrad_grid, p_vgrad_grid,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1308,10 +1387,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1308,10 +1387,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -1337,8 +1416,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1337,8 +1416,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1350,12 +1429,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1350,12 +1429,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1364,41 +1443,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1364,41 +1443,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a), return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_b), static_cast<const InputDataType*>(p_a),
static_cast<ZDataType*>(p_z), static_cast<const InputDataType*>(p_b),
static_cast<const InputDataType*>(p_b1), static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_c), static_cast<const InputDataType*>(p_b1),
static_cast<const LSEDataType*>(p_lse), static_cast<const InputDataType*>(p_c),
static_cast<DDataType*>(p_d_grid), static_cast<const LSEDataType*>(p_lse),
static_cast<const InputDataType*>(p_ygrad_grid), static_cast<DDataType*>(p_d_grid),
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
p_acc0_biases, // cast in struct Argument static_cast<OutputDataType*>(p_vgrad_grid),
p_acc1_biases, // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
a_gs_ms_ks_lengths, static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
a_gs_ms_ks_strides, a_gs_ms_ks_lengths,
b_gs_ns_ks_lengths, a_gs_ms_ks_strides,
b_gs_ns_ks_strides, b_gs_ns_ks_lengths,
z_gs_ms_ns_lengths, b_gs_ns_ks_strides,
z_gs_ms_ns_strides, z_gs_ms_ns_lengths,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
lse_gs_ms_lengths, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_strides, acc1_bias_gs_ms_gemm1ns_lengths,
a_element_op, acc1_bias_gs_ms_gemm1ns_strides,
b_element_op, a_element_op,
acc_element_op, b_element_op,
b1_element_op, acc_element_op,
c_element_op, b1_element_op,
p_drop, c_element_op,
seeds); p_drop,
seeds);
} }
// polymorphic // polymorphic
...@@ -1424,6 +1504,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1424,6 +1504,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
...@@ -82,6 +82,7 @@ __global__ void ...@@ -82,6 +82,7 @@ __global__ void
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename InputDataType, typename InputDataType,
typename D0DataType,
typename OutputDataType, typename OutputDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -93,6 +94,7 @@ template <typename GridwiseGemm, ...@@ -93,6 +94,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename D0GridDescriptor_M0_N0_M1_M2_N1_M3,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
...@@ -110,6 +112,7 @@ __global__ void ...@@ -110,6 +112,7 @@ __global__ void
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v2( kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v2(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
const D0DataType* __restrict__ p_d0_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_b1_grid, const InputDataType* __restrict__ p_b1_grid,
const LSEDataType* __restrict__ p_lse_grid, const LSEDataType* __restrict__ p_lse_grid,
...@@ -125,6 +128,7 @@ __global__ void ...@@ -125,6 +128,7 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
...@@ -168,6 +172,14 @@ __global__ void ...@@ -168,6 +172,14 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
...@@ -175,6 +187,7 @@ __global__ void ...@@ -175,6 +187,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
...@@ -191,6 +204,7 @@ __global__ void ...@@ -191,6 +204,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -209,6 +223,7 @@ __global__ void ...@@ -209,6 +223,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
...@@ -225,6 +240,7 @@ __global__ void ...@@ -225,6 +240,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -240,6 +256,7 @@ __global__ void ...@@ -240,6 +256,7 @@ __global__ void
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_d0_grid;
ignore = p_z_grid; ignore = p_z_grid;
ignore = p_b1_grid; ignore = p_b1_grid;
ignore = p_lse_grid; ignore = p_lse_grid;
...@@ -255,6 +272,7 @@ __global__ void ...@@ -255,6 +272,7 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
...@@ -307,6 +325,7 @@ template <index_t NumDimG, ...@@ -307,6 +325,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t B1K1, index_t B1K1,
...@@ -331,6 +350,7 @@ template <index_t NumDimG, ...@@ -331,6 +350,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
...@@ -351,12 +371,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -351,12 +371,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
static constexpr index_t DMPerBlock = BlockSize; static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(std::is_void<D1DataType>::value, "Acc1 Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2;
...@@ -364,9 +384,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -364,9 +384,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec, static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>, Number<MPerBlock>,
...@@ -404,31 +424,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -404,31 +424,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// V in Gemm B1 position // V in Gemm B1 position
static auto static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths,
b1_gs_gemm1ns_gemm1ks_strides_vec), b1_gs_gemm1ns_gemm1ks_strides),
Number<B1K1>{}); Number<B1K1>{});
} }
...@@ -437,8 +457,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -437,8 +457,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -464,17 +484,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -464,17 +484,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -503,17 +523,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -503,17 +523,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// YGrad in Gemm A position // YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec, static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides_vec) const std::vector<index_t>& y_gs_ms_os_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec), Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{}); Number<Y_O1>{});
} }
// V in Gemm B position // V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -539,17 +559,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -539,17 +559,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto v_grid_desc_nraw_oraw = const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw, const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
...@@ -560,11 +580,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -560,11 +580,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{}); return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
} }
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
// Z in Gemm0 C position // Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
// //
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
...@@ -575,10 +602,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -575,10 +602,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// QGrad in Gemm C position // QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec, static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides_vec) const std::vector<index_t>& q_gs_ms_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
} }
// //
...@@ -586,10 +613,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -586,10 +613,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// KGrad in Gemm C position // KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec, static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides_vec) const std::vector<index_t>& k_gs_ns_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -644,7 +671,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -644,7 +671,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
...@@ -655,6 +683,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -655,6 +683,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {})); using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{})); using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -678,14 +707,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -678,14 +707,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch() {}
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
...@@ -703,6 +735,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -703,6 +735,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -726,6 +762,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -726,6 +762,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
...@@ -736,6 +773,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -736,6 +773,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -752,6 +790,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -752,6 +790,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
KGridDesc_N_K, KGridDesc_N_K,
D0GridDesc_M_N,
ZGridDesc_M_N, ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
...@@ -763,6 +802,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -763,6 +802,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
KPerBlock, KPerBlock,
Gemm1NPerBlock, Gemm1NPerBlock,
Gemm1KPerBlock, Gemm1KPerBlock,
Gemm2KPerBlock,
AK1, AK1,
BK1, BK1,
B1K1, B1K1,
...@@ -788,6 +828,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -788,6 +828,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
BBlockLdsExtraN, BBlockLdsExtraN,
D0BlockTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
...@@ -817,46 +858,46 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -817,46 +858,46 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const InputDataType* p_a_grid,
const InputDataType* p_a_grid, const InputDataType* p_b_grid,
const InputDataType* p_b_grid, ZDataType* p_z_grid,
ZDataType* p_z_grid, const InputDataType* p_b1_grid,
const InputDataType* p_b1_grid, const InputDataType* p_c_grid, // for dS
const InputDataType* p_c_grid, // for dS const LSEDataType* p_lse_grid,
const LSEDataType* p_lse_grid, DDataType* p_d_grid,
DDataType* p_d_grid, const InputDataType* p_ygrad_grid,
const InputDataType* p_ygrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_vgrad_grid,
OutputDataType* p_vgrad_grid, const D0DataType* p_acc0_bias,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D1DataType* p_acc1_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op, BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op, float p_drop,
float p_drop, std::tuple<unsigned long long, unsigned long long> seeds)
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_bias},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -871,7 +912,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -871,7 +912,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
...@@ -916,22 +957,35 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -916,22 +957,35 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc0_biases; ignore = p_acc1_bias;
ignore = p_acc1_biases; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc0_biases_gs_ms_ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_strides;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; if constexpr(!is_same<D0DataType, void>::value)
ignore = acc1_biases_gs_ms_gemm1ns_strides; {
const auto d0_grid_desc_m_n = MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_strides[NumDimG + NumDimM]);
}
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -975,6 +1029,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -975,6 +1029,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// pointers // pointers
const InputDataType* p_a_grid_; const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_; const InputDataType* p_b_grid_;
const D0DataType* p_d0_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_; const InputDataType* p_c_grid_;
...@@ -988,6 +1043,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -988,6 +1043,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
...@@ -1000,6 +1056,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1000,6 +1056,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// batch offsets // batch offsets
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
...@@ -1039,6 +1096,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1039,6 +1096,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
index_t m_raw_padded_; index_t m_raw_padded_;
index_t n_raw_padded_; index_t n_raw_padded_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Invoker // Invoker
...@@ -1103,6 +1163,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1103,6 +1163,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v2< kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_light_v2<
GridwiseGemm, GridwiseGemm,
InputDataType, InputDataType,
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -1114,6 +1175,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1114,6 +1175,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
...@@ -1133,6 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1133,6 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_d0_grid_,
arg.p_z_grid_, arg.p_z_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_lse_grid_, arg.p_lse_grid_,
...@@ -1148,6 +1211,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1148,6 +1211,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
...@@ -1218,17 +1282,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1218,17 +1282,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{ {
return false; return false;
} }
if constexpr(!is_same<D0DataType, void>::value)
{
if(arg.d0_n_length_stride_[1] == 1 &&
arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
...@@ -1279,44 +1356,44 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1279,44 +1356,44 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const InputDataType* p_a, MakeArgument(const InputDataType* p_a,
const InputDataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const InputDataType* p_b1, const InputDataType* p_b1,
const InputDataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
DDataType* p_d_grid, DDataType* p_d_grid,
const InputDataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -1329,8 +1406,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1329,8 +1406,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_vgrad_grid, p_vgrad_grid,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1342,10 +1419,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1342,10 +1419,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -1371,8 +1448,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1371,8 +1448,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1384,12 +1461,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1384,12 +1461,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1398,41 +1475,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1398,41 +1475,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a), return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_b), static_cast<const InputDataType*>(p_a),
static_cast<ZDataType*>(p_z), static_cast<const InputDataType*>(p_b),
static_cast<const InputDataType*>(p_b1), static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_c), static_cast<const InputDataType*>(p_b1),
static_cast<const LSEDataType*>(p_lse), static_cast<const InputDataType*>(p_c),
static_cast<DDataType*>(p_d_grid), static_cast<const LSEDataType*>(p_lse),
static_cast<const InputDataType*>(p_ygrad_grid), static_cast<DDataType*>(p_d_grid),
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
p_acc0_biases, // cast in struct Argument static_cast<OutputDataType*>(p_vgrad_grid),
p_acc1_biases, // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
a_gs_ms_ks_lengths, static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
a_gs_ms_ks_strides, a_gs_ms_ks_lengths,
b_gs_ns_ks_lengths, a_gs_ms_ks_strides,
b_gs_ns_ks_strides, b_gs_ns_ks_lengths,
z_gs_ms_ns_lengths, b_gs_ns_ks_strides,
z_gs_ms_ns_strides, z_gs_ms_ns_lengths,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
lse_gs_ms_lengths, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_strides, acc1_bias_gs_ms_gemm1ns_lengths,
a_element_op, acc1_bias_gs_ms_gemm1ns_strides,
b_element_op, a_element_op,
acc_element_op, b_element_op,
b1_element_op, acc_element_op,
c_element_op, b1_element_op,
p_drop, c_element_op,
seeds); p_drop,
seeds);
} }
// polymorphic // polymorphic
...@@ -1458,6 +1536,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1458,6 +1536,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
...@@ -566,9 +566,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -566,9 +566,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return lse_grid_desc_mraw; return lse_grid_desc_mraw;
} }
} }
// D in Gemm0 C position // D0 in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides) const std::vector<index_t>& d_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
} }
...@@ -585,7 +585,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -585,7 +585,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {})); using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -857,8 +857,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -857,8 +857,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const auto d0_grid_desc_m_n = const auto d0_grid_desc_m_n = MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ = d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n); GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
......
...@@ -518,9 +518,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -518,9 +518,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{}); return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
} }
// D in Gemm0 C position // D0 in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides) const std::vector<index_t>& d_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
} }
...@@ -594,7 +594,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -594,7 +594,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{})); using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -870,8 +870,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -870,8 +870,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const auto d0_grid_desc_m_n = const auto d0_grid_desc_m_n = MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ = d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n); GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
......
...@@ -82,6 +82,7 @@ __global__ void ...@@ -82,6 +82,7 @@ __global__ void
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename D0DataType,
typename GroupKernelArg, typename GroupKernelArg,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -156,6 +157,15 @@ __global__ void ...@@ -156,6 +157,15 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
}
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
...@@ -163,6 +173,7 @@ __global__ void ...@@ -163,6 +173,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
...@@ -179,6 +190,7 @@ __global__ void ...@@ -179,6 +190,7 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
...@@ -198,6 +210,7 @@ __global__ void ...@@ -198,6 +210,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
...@@ -214,6 +227,7 @@ __global__ void ...@@ -214,6 +227,7 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
...@@ -276,6 +290,7 @@ template <index_t NumDimG, ...@@ -276,6 +290,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t B1K1, index_t B1K1,
...@@ -300,6 +315,7 @@ template <index_t NumDimG, ...@@ -300,6 +315,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -313,12 +329,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -313,12 +329,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
static constexpr index_t DMPerBlock = BlockSize; static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(is_same<D1DataType, void>::value, "Bias1 addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1; using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1;
struct ProblemDesc struct ProblemDesc
...@@ -341,19 +357,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -341,19 +357,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_bias_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_bias_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_bias_gs_ms_os_strides;
}; };
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec, static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>, Number<MPerBlock>,
...@@ -391,20 +407,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -391,20 +407,20 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// //
...@@ -412,8 +428,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -412,8 +428,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -439,17 +455,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -439,17 +455,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -460,17 +476,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -460,17 +476,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// //
// dQ = alpha * dS * K // dQ = alpha * dS * K
// //
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec, static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides_vec) const std::vector<index_t>& y_gs_ms_os_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec), Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{}); Number<Y_O1>{});
} }
// V in Gemm B position // V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -496,17 +512,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -496,17 +512,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto v_grid_desc_nraw_oraw = const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw, const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
...@@ -517,10 +533,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -517,10 +533,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{}); return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
} }
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -547,6 +563,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -547,6 +563,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return lse_grid_desc_mraw; return lse_grid_desc_mraw;
} }
} }
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeDGridDescriptor_M(index_t MRaw)
{ {
...@@ -580,11 +613,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -580,11 +613,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {})); using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -612,12 +647,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -612,12 +647,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t batch_stride_lse) index_t batch_stride_lse)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
...@@ -635,6 +672,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -635,6 +672,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -658,6 +700,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -658,6 +700,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
...@@ -667,6 +710,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -667,6 +710,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -683,6 +727,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -683,6 +727,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
KGridDesc_N_K, KGridDesc_N_K,
D0GridDesc_M_N,
ZGridDesc_M_N, ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
...@@ -694,6 +739,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -694,6 +739,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
KPerBlock, KPerBlock,
Gemm1NPerBlock, Gemm1NPerBlock,
Gemm1KPerBlock, Gemm1KPerBlock,
Gemm2KPerBlock,
AK1, AK1,
BK1, BK1,
B1K1, B1K1,
...@@ -719,6 +765,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -719,6 +765,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
BBlockLdsExtraN, BBlockLdsExtraN,
D0BlockTransferSrcScalarPerVector,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -747,6 +794,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -747,6 +794,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// pointers // pointers
const InputDataType* p_a_grid_; const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_; const InputDataType* p_b_grid_;
const D0DataType* p_d0_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_; const InputDataType* p_c_grid_;
...@@ -759,6 +807,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -759,6 +807,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
...@@ -805,6 +854,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -805,6 +854,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t batch_count_; index_t batch_count_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -820,8 +872,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -820,8 +872,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias_vec,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -852,16 +904,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -852,16 +904,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()))) group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
if(!(p_acc0_biases.size() == p_acc1_biases.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0; grid_size_ = 0;
index_t z_random_matrix_offset = 0; index_t z_random_matrix_offset = 0;
...@@ -870,8 +920,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -870,8 +920,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias_vec[i])
: nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]); const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
...@@ -887,6 +941,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -887,6 +941,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
}
else
{
tmp_d0_gs_ms_ns_lengths = {1, 1, 1, 1};
tmp_d0_gs_ms_ns_strides = {0, 0, 0, 0};
}
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N( const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1( const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
...@@ -906,6 +977,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -906,6 +977,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K( const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
...@@ -931,6 +1004,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -931,6 +1004,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k, a_grid_desc_g_m_k,
b_grid_desc_g_n_k, b_grid_desc_g_n_k,
d0_grid_desc_g_m_n,
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
...@@ -942,18 +1016,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -942,18 +1016,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
const auto raw_m_padded = GridwiseGemm::GetPaddedSize( const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]); problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize( const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
...@@ -980,6 +1042,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -980,6 +1042,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_d0_grid,
p_z_grid, p_z_grid,
p_b1_grid, p_b1_grid,
p_c_grid, p_c_grid,
...@@ -990,6 +1053,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -990,6 +1053,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
...@@ -1017,6 +1081,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1017,6 +1081,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
z_random_matrix_offset = z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count; z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
// for check
std::vector<ck::index_t> d0_n_length_stride;
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
...@@ -1031,15 +1100,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1031,15 +1100,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
batch_count}); batch_count,
d0_n_length_stride});
} }
// TODO: implement bias addition // TODO: implement bias addition
// ignore = p_acc0_biases; // ignore = p_acc0_bias_vec;
// ignore = p_acc1_biases; // ignore = p_acc1_bias_vec;
// ignore = acc0_biases_gs_ms_ns_lengths; // ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_biases_gs_ms_ns_strides; // ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_biases_gs_ms_gemm1ns_lengths; // ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_biases_gs_ms_gemm1ns_strides; // ignore = acc1_bias_gs_ms_gemm1ns_strides;
} }
// element-wise op // element-wise op
...@@ -1114,6 +1184,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1114,6 +1184,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto kernel = const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1< kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1<
GridwiseGemm, GridwiseGemm,
D0DataType,
GroupKernelArg, GroupKernelArg,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -1211,6 +1282,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1211,6 +1282,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false; return false;
} }
if constexpr(!is_same<D0DataType, void>::value)
{
if(device_arg.d0_n_length_stride_[1] == 1 &&
device_arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(device_arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part // Note: we need raw lengths since threadwise copy can not handle vector load when part
// of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
...@@ -1279,8 +1363,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1279,8 +1363,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias_vec,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1290,16 +1374,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1290,16 +1374,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, p_Bs, return Argument{p_As,
p_Zs, p_B1s, p_Bs,
p_Cs, p_LSEs, p_Zs,
p_Ds, p_Ygrads, p_B1s,
p_Qgrads, p_Kgrads, p_Cs,
p_Vgrads, p_acc0_biases, p_LSEs,
p_acc1_biases, problem_desc_vec, p_Ds,
a_element_op, b_element_op, p_Ygrads,
acc_element_op, b1_element_op, p_Qgrads,
c_element_op, p_drop, p_Kgrads,
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
problem_desc_vec,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds}; seeds};
} }
...@@ -1319,8 +1413,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1319,8 +1413,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias_vec,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1341,8 +1435,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1341,8 +1435,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_biases, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1376,6 +1470,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1376,6 +1470,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
...@@ -81,6 +81,7 @@ __global__ void ...@@ -81,6 +81,7 @@ __global__ void
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename D0DataType,
typename GroupKernelArg, typename GroupKernelArg,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -154,6 +155,15 @@ __global__ void ...@@ -154,6 +155,15 @@ __global__ void
auto z_matrix_ptr = auto z_matrix_ptr =
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
}
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
...@@ -162,6 +172,7 @@ __global__ void ...@@ -162,6 +172,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
...@@ -178,6 +189,7 @@ __global__ void ...@@ -178,6 +189,7 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
...@@ -197,6 +209,7 @@ __global__ void ...@@ -197,6 +209,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
...@@ -213,6 +226,7 @@ __global__ void ...@@ -213,6 +226,7 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
...@@ -275,6 +289,7 @@ template <index_t NumDimG, ...@@ -275,6 +289,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t B1K1, index_t B1K1,
...@@ -299,6 +314,7 @@ template <index_t NumDimG, ...@@ -299,6 +314,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
...@@ -319,12 +335,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -319,12 +335,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
static constexpr index_t DMPerBlock = BlockSize; static constexpr index_t DMPerBlock = BlockSize;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(is_same<D1DataType, void>::value, "Bias1 addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2; using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2;
struct ProblemDesc struct ProblemDesc
...@@ -347,19 +363,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -347,19 +363,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_bias_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_bias_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_bias_gs_ms_os_strides;
}; };
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec, static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>, Number<MPerBlock>,
...@@ -397,31 +413,31 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -397,31 +413,31 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
*/ */
// Q in Gemm A position // Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{}); Number<AK1>{});
} }
// K in Gemm B0 position // K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides),
Number<BK1>{}); Number<BK1>{});
} }
// V in Gemm B1 position // V in Gemm B1 position
static auto static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths,
b1_gs_gemm1ns_gemm1ks_strides_vec), b1_gs_gemm1ns_gemm1ks_strides),
Number<B1K1>{}); Number<B1K1>{});
} }
...@@ -430,8 +446,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -430,8 +446,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// VGrad in Gemm C position // VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides_vec) const std::vector<index_t>& v_gs_os_ns_strides)
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
...@@ -457,17 +473,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -457,17 +473,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++) for(int i = 0; i < num_dims; i++)
{ {
index_t id_new = ids_old2new[i]; index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
} }
const auto vgrad_desc_nraw_oraw = const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second; .second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
...@@ -490,6 +506,69 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -490,6 +506,69 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{});
}
// V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
}
const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// //
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// //
...@@ -499,10 +578,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -499,10 +578,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// QGrad in Gemm C position // QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec, static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides_vec) const std::vector<index_t>& q_gs_ms_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
} }
// //
...@@ -510,16 +589,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -510,16 +589,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// KGrad in Gemm C position // KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec, static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides_vec) const std::vector<index_t>& k_gs_ns_ks_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -546,6 +625,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -546,6 +625,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return lse_grid_desc_mraw; return lse_grid_desc_mraw;
} }
} }
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeDGridDescriptor_M(index_t MRaw)
{ {
...@@ -574,16 +670,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -574,16 +670,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{})); using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -611,12 +709,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -611,12 +709,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
...@@ -634,6 +734,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -634,6 +734,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -657,6 +762,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -657,6 +762,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
...@@ -666,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -666,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -682,6 +789,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -682,6 +789,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
KGridDesc_N_K, KGridDesc_N_K,
D0GridDesc_M_N,
ZGridDesc_M_N, ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
...@@ -693,6 +801,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -693,6 +801,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
KPerBlock, KPerBlock,
Gemm1NPerBlock, Gemm1NPerBlock,
Gemm1KPerBlock, Gemm1KPerBlock,
Gemm2KPerBlock,
AK1, AK1,
BK1, BK1,
B1K1, B1K1,
...@@ -718,6 +827,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -718,6 +827,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
BBlockLdsExtraN, BBlockLdsExtraN,
D0BlockTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
...@@ -754,6 +864,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -754,6 +864,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// pointers // pointers
const InputDataType* p_a_grid_; const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_; const InputDataType* p_b_grid_;
const D0DataType* p_d0_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_; const InputDataType* p_c_grid_;
...@@ -766,6 +877,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -766,6 +877,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
...@@ -812,6 +924,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -812,6 +924,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t batch_count_; index_t batch_count_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -827,8 +942,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -827,8 +942,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias_vec,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -859,16 +974,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -859,16 +974,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()))) group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
if(!(p_acc0_biases.size() == p_acc1_biases.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0; grid_size_ = 0;
index_t z_random_matrix_offset = 0; index_t z_random_matrix_offset = 0;
...@@ -877,8 +990,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -877,8 +990,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias_vec[i])
: nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]); const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
...@@ -894,9 +1011,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -894,9 +1011,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
}
else
{
tmp_d0_gs_ms_ns_lengths = {1, 1, 1, 1};
tmp_d0_gs_ms_ns_strides = {0, 0, 0, 0};
}
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N( const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides); problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N( const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
...@@ -913,6 +1047,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -913,6 +1047,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K( const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
...@@ -938,6 +1074,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -938,6 +1074,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k, a_grid_desc_g_m_k,
b_grid_desc_g_n_k, b_grid_desc_g_n_k,
d0_grid_desc_g_m_n,
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
...@@ -949,18 +1086,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -949,18 +1086,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
const auto raw_m_padded = GridwiseGemm::GetPaddedSize( const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]); problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize( const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
...@@ -987,6 +1112,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -987,6 +1112,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_d0_grid,
p_z_grid, p_z_grid,
p_b1_grid, p_b1_grid,
p_c_grid, p_c_grid,
...@@ -997,6 +1123,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -997,6 +1123,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
...@@ -1024,6 +1151,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1024,6 +1151,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_random_matrix_offset = z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count; z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
// for check
std::vector<ck::index_t> d0_n_length_stride;
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
...@@ -1038,15 +1170,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1038,15 +1170,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
batch_count}); batch_count,
d0_n_length_stride});
} }
// TODO: implement bias addition // TODO: implement bias addition
// ignore = p_acc0_biases; // ignore = p_acc0_bias_vec;
// ignore = p_acc1_biases; // ignore = p_acc1_bias_vec;
// ignore = acc0_biases_gs_ms_ns_lengths; // ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_biases_gs_ms_ns_strides; // ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_biases_gs_ms_gemm1ns_lengths; // ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_biases_gs_ms_gemm1ns_strides; // ignore = acc1_bias_gs_ms_gemm1ns_strides;
} }
// element-wise op // element-wise op
...@@ -1120,6 +1253,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1120,6 +1253,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto kernel = const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2< kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2<
GridwiseGemm, GridwiseGemm,
D0DataType,
GroupKernelArg, GroupKernelArg,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -1209,13 +1343,27 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1209,13 +1343,27 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{ {
return false; return false;
} }
if constexpr(!is_same<D0DataType, void>::value)
{
if(device_arg.d0_n_length_stride_[1] == 1 &&
device_arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(device_arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part // Note: we need raw lengths since threadwise copy can not handle vector load when part
// of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // of vector is out of bounds Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
...@@ -1290,8 +1438,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1290,8 +1438,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias_vec,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1301,16 +1449,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1301,16 +1449,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_As, p_Bs, return Argument{p_As,
p_Zs, p_B1s, p_Bs,
p_Cs, p_LSEs, p_Zs,
p_Ds, p_Ygrads, p_B1s,
p_Qgrads, p_Kgrads, p_Cs,
p_Vgrads, p_acc0_biases, p_LSEs,
p_acc1_biases, problem_desc_vec, p_Ds,
a_element_op, b_element_op, p_Ygrads,
acc_element_op, b1_element_op, p_Qgrads,
c_element_op, p_drop, p_Kgrads,
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
problem_desc_vec,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds}; seeds};
} }
...@@ -1330,8 +1488,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1330,8 +1488,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::array<void*, NumAcc0Bias>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias_vec,
const std::array<void*, NumAcc1Bias>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1352,8 +1510,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1352,8 +1510,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_biases, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1387,6 +1545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1387,6 +1545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
...@@ -498,7 +498,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -498,7 +498,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return lse_grid_desc_mraw; return lse_grid_desc_mraw;
} }
} }
// D in Gemm0 C position // D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
......
...@@ -561,7 +561,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -561,7 +561,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return lse_grid_desc_mraw; return lse_grid_desc_mraw;
} }
} }
// D in Gemm0 C position // D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment