Commit ec2ad713 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'mha-train-develop' into mha-train-bias-bwd-type2

parents e3eb4381 e296ee56
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -121,6 +121,7 @@ using DeviceGemmInstance = ...@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -194,6 +195,7 @@ using DeviceGemmInstance = ...@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -257,7 +259,7 @@ using DeviceGemmInstance = ...@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -266,7 +268,8 @@ using DeviceGemmInstance = ...@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -282,7 +285,7 @@ using DeviceGemmInstance = ...@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8, 8,
true, true,
4, 4,
S<8, 32, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
......
...@@ -125,8 +125,8 @@ using DeviceGemmInstanceFWD = ...@@ -125,8 +125,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -259,8 +259,8 @@ using DeviceGemmInstanceFWD = ...@@ -259,8 +259,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -463,8 +463,8 @@ using DeviceGemmInstanceFWD = ...@@ -463,8 +463,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
......
...@@ -113,11 +113,11 @@ static constexpr bool Deterministic = false; ...@@ -113,11 +113,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 4, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, void, void, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 4, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -129,11 +129,11 @@ using DeviceGemmInstanceBWD = ...@@ -129,11 +129,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64) #elif(DIM <= 64)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, void, void, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -152,11 +152,11 @@ using DeviceGemmInstanceBWD = ...@@ -152,11 +152,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128) #elif(DIM <= 128)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, void, void, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 4, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -534,8 +534,8 @@ int run(int argc, char* argv[]) ...@@ -534,8 +534,8 @@ int run(int argc, char* argv[])
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; nullptr, // p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
...@@ -594,8 +594,8 @@ int run(int argc, char* argv[]) ...@@ -594,8 +594,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; {}, // p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; {}, // p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
......
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -121,6 +121,7 @@ using DeviceGemmInstance = ...@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -194,6 +195,7 @@ using DeviceGemmInstance = ...@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -257,7 +259,7 @@ using DeviceGemmInstance = ...@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -266,7 +268,8 @@ using DeviceGemmInstance = ...@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -282,7 +285,7 @@ using DeviceGemmInstance = ...@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8, 8,
true, true,
1, 1,
S<8, 32, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
......
...@@ -124,8 +124,8 @@ using DeviceGemmInstanceFWD = ...@@ -124,8 +124,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -258,8 +258,8 @@ using DeviceGemmInstanceFWD = ...@@ -258,8 +258,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
...@@ -462,8 +462,8 @@ using DeviceGemmInstanceFWD = ...@@ -462,8 +462,8 @@ using DeviceGemmInstanceFWD =
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, void,
Acc1BiasDataType, void,
AccDataType, AccDataType,
ShuffleDataType, ShuffleDataType,
QKVElementOp, QKVElementOp,
......
...@@ -112,11 +112,11 @@ static constexpr bool Deterministic = false; ...@@ -112,11 +112,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 1, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, void, void, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 1, 4, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 1, S<1, 64, 1, 4>, 8, 1, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -128,11 +128,11 @@ using DeviceGemmInstanceBWD = ...@@ -128,11 +128,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64) #elif(DIM <= 64)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, void, void, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...@@ -151,11 +151,11 @@ using DeviceGemmInstanceBWD = ...@@ -151,11 +151,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128) #elif(DIM <= 128)
// clang-format off // clang-format off
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic| // #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec| Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | | // #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| | |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | | // #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| | |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, InputDataType, InputDataType, InputDataType, GemmDataType, ZDataType, LSEDataType, void, void, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, 1, MaskingSpec, Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
......
...@@ -177,8 +177,8 @@ int run(int argc, char* argv[]) ...@@ -177,8 +177,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; nullptr, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
......
...@@ -50,11 +50,10 @@ using B1DataType = DataType; ...@@ -50,11 +50,10 @@ using B1DataType = DataType;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using DDataType = F16;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = F16;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -122,6 +121,7 @@ using DeviceGemmInstance = ...@@ -122,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -195,6 +195,7 @@ using DeviceGemmInstance = ...@@ -195,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -268,6 +269,7 @@ using DeviceGemmInstance = ...@@ -268,6 +269,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
......
...@@ -48,13 +48,12 @@ using ADataType = DataType; ...@@ -48,13 +48,12 @@ using ADataType = DataType;
using B0DataType = DataType; using B0DataType = DataType;
using B1DataType = DataType; using B1DataType = DataType;
using AccDataType = F32; using AccDataType = F32;
using DDataType = F16;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = F16;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -122,6 +121,7 @@ using DeviceGemmInstance = ...@@ -122,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -195,6 +195,7 @@ using DeviceGemmInstance = ...@@ -195,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -268,6 +269,7 @@ using DeviceGemmInstance = ...@@ -268,6 +269,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
......
...@@ -116,7 +116,7 @@ int run(int argc, char* argv[]) ...@@ -116,7 +116,7 @@ int run(int argc, char* argv[])
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
...@@ -137,25 +137,25 @@ int run(int argc, char* argv[]) ...@@ -137,25 +137,25 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -163,7 +163,7 @@ int run(int argc, char* argv[]) ...@@ -163,7 +163,7 @@ int run(int argc, char* argv[])
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * DeviceMem lse_device_buf(sizeof(LSEDataType) *
lse_gs_ms_device_result.mDesc.GetElementSpaceSize()); lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
...@@ -181,40 +181,40 @@ int run(int argc, char* argv[]) ...@@ -181,40 +181,40 @@ int run(int argc, char* argv[])
// do GEMM // do GEMM
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument( auto argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), //
{}, // std::array<void*, 1> p_acc1_biases; nullptr,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides, b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths, b1_gs_os_ns_lengths,
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::vector<ck::index_t>
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::vector<ck::index_t>
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number of {seed, offset}); // dropout random seed and offset, offset should be at
// elements on a thread // least the number of elements on a thread
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -227,15 +227,16 @@ int run(int argc, char* argv[]) ...@@ -227,15 +227,16 @@ int run(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_bytes =
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
sizeof(DDataType) * M * N * Acc0BiasDataType::Size()) * sizeof(CDataType) * M * O +
BatchCount; sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl; << gemm.GetTypeString() << std::endl;
...@@ -243,41 +244,38 @@ int run(int argc, char* argv[]) ...@@ -243,41 +244,38 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// run for storing z tensor // run for storing z tensor
argument = gemm.MakeArgument( argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{ static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),
d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; nullptr,
{}, // std::array<void*, 1> p_acc1_biases; a_gs_ms_ks_lengths,
a_gs_ms_ks_lengths, a_gs_ms_ks_strides,
a_gs_ms_ks_strides, b0_gs_ns_ks_lengths,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides,
b0_gs_ns_ks_strides, b1_gs_os_ns_lengths,
b1_gs_os_ns_lengths, b1_gs_os_ns_strides,
b1_gs_os_ns_strides, c_gs_ms_os_lengths,
c_gs_ms_os_lengths, c_gs_ms_os_strides,
c_gs_ms_os_strides, z_gs_ms_ns_lengths,
z_gs_ms_ns_lengths, z_gs_ms_ns_strides,
z_gs_ms_ns_strides, lse_gs_ms_lengths,
lse_gs_ms_lengths, d_gs_ms_ns_lengths,
std::array<std::vector<ck::index_t>, 1>{ d_gs_ms_ns_strides,
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths {},
std::array<std::vector<ck::index_t>, 1>{ {},
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides a_element_op,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, b0_element_op,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, acc0_element_op,
a_element_op, b1_element_op,
b0_element_op, c_element_op,
acc0_element_op, p_drop, // dropout ratio
b1_element_op, {seed, offset}); // dropout random seed and offset, offset should be
c_element_op, // at least the number of elements on a thread
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
c_device_buf.SetZero(); c_device_buf.SetZero();
lse_device_buf.SetZero(); lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
...@@ -294,7 +292,7 @@ int run(int argc, char* argv[]) ...@@ -294,7 +292,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
Tensor<LSEDataType> lse_g_m_host_result( Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum) {BatchCount, M}); // scratch object after max + ln(sum)
Tensor<DDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
...@@ -324,12 +322,12 @@ int run(int argc, char* argv[]) ...@@ -324,12 +322,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias // bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); }); acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
}); });
// softmax // softmax
......
...@@ -57,7 +57,7 @@ int run(int argc, char* argv[]) ...@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0; std::vector<const void*> p_b0;
std::vector<const void*> p_b1; std::vector<const void*> p_b1;
std::vector<void*> p_c; std::vector<void*> p_c;
std::vector<std::vector<const void*>> p_d; std::vector<const void*> p_d;
std::vector<void*> p_z; // for result verification std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse; std::vector<void*> p_lse;
...@@ -67,7 +67,7 @@ int run(int argc, char* argv[]) ...@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<B0DataType>> b0_tensors; std::vector<Tensor<B0DataType>> b0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors; std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<CDataType>> c_tensors; std::vector<Tensor<CDataType>> c_tensors;
std::vector<Tensor<DDataType>> d_tensors; std::vector<Tensor<Acc0BiasDataType>> d_tensors;
std::vector<Tensor<ZDataType>> z_tensors; std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors; std::vector<Tensor<LSEDataType>> lse_tensors;
...@@ -147,10 +147,8 @@ int run(int argc, char* argv[]) ...@@ -147,10 +147,8 @@ int run(int argc, char* argv[])
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
std::vector<std::vector<ck::index_t>>{ d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
std::vector<std::vector<ck::index_t>>{
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths {}, // acc1_biases_gs_ms_os_lengths
{}}); // acc1_biases_gs_ms_os_strides {}}); // acc1_biases_gs_ms_os_strides
...@@ -159,7 +157,7 @@ int run(int argc, char* argv[]) ...@@ -159,7 +157,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
...@@ -167,7 +165,7 @@ int run(int argc, char* argv[]) ...@@ -167,7 +165,7 @@ int run(int argc, char* argv[])
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * (Acc0BiasDataType::Size() ? 0 : 1)) * sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
Batch; Batch;
if(i < 4) if(i < 4)
...@@ -191,25 +189,25 @@ int run(int argc, char* argv[]) ...@@ -191,25 +189,25 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
} }
a_tensors.push_back(a_gs_ms_ks); a_tensors.push_back(a_gs_ms_ks);
...@@ -229,7 +227,7 @@ int run(int argc, char* argv[]) ...@@ -229,7 +227,7 @@ int run(int argc, char* argv[])
c_tensors_device.emplace_back(std::make_unique<DeviceMem>( c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize())); sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>( d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize())); sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()));
z_tensors_device.emplace_back(std::make_unique<DeviceMem>( z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize())); sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>( lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
...@@ -244,9 +242,7 @@ int run(int argc, char* argv[]) ...@@ -244,9 +242,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_d.push_back({d_tensors_device[i]->GetDeviceBuffer()}); p_d.push_back(d_tensors_device[i]->GetDeviceBuffer());
// std::cout << "from host group id: " << i << " d address: " <<
// d_tensors_device[i]->GetDeviceBuffer() << std::endl;
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer()); p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr); p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
...@@ -363,7 +359,7 @@ int run(int argc, char* argv[]) ...@@ -363,7 +359,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N}); Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O}); Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<AccDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
...@@ -400,12 +396,12 @@ int run(int argc, char* argv[]) ...@@ -400,12 +396,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias // bias
acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); }); acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<AccDataType>::Infinity();
}); });
// softmax // softmax
......
...@@ -138,12 +138,12 @@ struct BlockwiseDropout ...@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
} }
block_sync_lds(); block_sync_lds();
...@@ -179,12 +179,12 @@ struct BlockwiseDropout ...@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
} }
block_sync_lds(); block_sync_lds();
...@@ -218,21 +218,19 @@ struct BlockwiseDropout ...@@ -218,21 +218,19 @@ struct BlockwiseDropout
} }
// get raw z matrix with random number for shuffle // get raw z matrix with random number for shuffle
template <typename ZThreadBuffer, template <typename ZThreadBuffer, typename Step, typename Offset>
typename Step,
typename Offset> // N3*N4=8
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph, __host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf) ZThreadBuffer& z_thread_buf)
{ {
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value; constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 8;
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{}); ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
} }
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; }); static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
......
...@@ -87,9 +87,6 @@ template <index_t NumDimG, ...@@ -87,9 +87,6 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionForward : public BaseOperator struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
{ {
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b0, const void* p_b0,
...@@ -97,8 +94,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator ...@@ -97,8 +94,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
void* p_c, void* p_c,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -110,12 +107,10 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator ...@@ -110,12 +107,10 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths
const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias> const std::vector<index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<index_t>& acc1_bias_gs_ms_gemm1ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
......
...@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_biases_gs_ms_os_strides;
}; };
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
......
...@@ -289,12 +289,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -289,12 +289,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0 #if 0
// TODO ANT: use alias // TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM; static constexpr index_t NumDimGemm0M = NumDimM;
...@@ -535,39 +529,36 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -535,39 +529,36 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
// FIXME: constness // FIXME: constness
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const ADataType* p_a_grid,
const ADataType* p_a_grid, const BDataType* p_b_grid,
const BDataType* p_b_grid, const B1DataType* p_b1_grid,
const B1DataType* p_b1_grid, CDataType* p_c_grid,
CDataType* p_c_grid, ZDataType* p_z_grid,
ZDataType* p_z_grid, LSEDataType* p_lse_grid,
LSEDataType* p_lse_grid, const void* p_acc0_bias,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc1_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<ck::index_t> acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t> acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t> acc1_bias_gs_ms_gemm1ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t> acc1_bias_gs_ms_gemm1ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths AElementwiseOperation a_element_op,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> BElementwiseOperation b_element_op,
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AccElementwiseOperation acc_element_op,
AElementwiseOperation a_element_op, B1ElementwiseOperation b1_element_op,
BElementwiseOperation b_element_op, CElementwiseOperation c_element_op,
AccElementwiseOperation acc_element_op, float p_dropout,
B1ElementwiseOperation b1_element_op, std::tuple<unsigned long long, unsigned long long> seeds)
CElementwiseOperation c_element_op,
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
...@@ -624,12 +615,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -624,12 +615,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_bias;
ignore = p_acc1_biases; ignore = p_acc1_bias;
ignore = acc0_biases_gs_ms_ns_lengths; ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides; ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -984,39 +975,37 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -984,39 +975,37 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const ADataType* p_a, MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
const B1DataType* p_b1, const B1DataType* p_b1,
CDataType* p_c, CDataType* p_c,
ZDataType* p_z, ZDataType* p_z,
LSEDataType* p_lse, LSEDataType* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> AElementwiseOperation a_element_op,
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides BElementwiseOperation b_element_op,
AElementwiseOperation a_element_op, AccElementwiseOperation acc_element_op,
BElementwiseOperation b_element_op, B1ElementwiseOperation b1_element_op,
AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op,
B1ElementwiseOperation b1_element_op, float p_dropout,
CElementwiseOperation c_element_op, std::tuple<unsigned long long, unsigned long long> seeds)
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -1024,8 +1013,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1024,8 +1013,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c, p_c,
p_z, p_z,
p_lse, p_lse,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1037,10 +1026,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1037,10 +1026,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_bias_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -1061,8 +1050,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1061,8 +1050,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
void* p_c, void* p_c,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_bias,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1074,12 +1063,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1074,12 +1063,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>& acc1_bias_gs_ms_gemm1ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1094,8 +1081,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1094,8 +1081,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<LSEDataType*>(p_lse), static_cast<LSEDataType*>(p_lse),
p_acc0_biases, // cast in struct Argument p_acc0_bias, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_bias, // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1107,10 +1094,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1107,10 +1094,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides, acc1_bias_gs_ms_gemm1ns_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
......
...@@ -25,7 +25,7 @@ namespace device { ...@@ -25,7 +25,7 @@ namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename D0sPointer, typename D0DataType,
typename FloatC, typename FloatC,
typename ZDataType, typename ZDataType,
typename FloatLSE, typename FloatLSE,
...@@ -37,10 +37,10 @@ template <typename GridwiseGemm, ...@@ -37,10 +37,10 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
...@@ -56,7 +56,7 @@ __global__ void ...@@ -56,7 +56,7 @@ __global__ void
kernel_batched_multiheadattention_forward_xdl_cshuffle_v2( kernel_batched_multiheadattention_forward_xdl_cshuffle_v2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
D0sPointer p_d0s_grid, const D0DataType* __restrict__ p_d0_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -68,13 +68,13 @@ __global__ void ...@@ -68,13 +68,13 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
...@@ -107,11 +107,15 @@ __global__ void ...@@ -107,11 +107,15 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; const D0DataType* tmp_p_d0_grid = nullptr;
});
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
// const index_t global_thread_id = get_thread_global_1d_id(); // const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, 0, offset); ck::philox ph(seed, 0, offset);
...@@ -125,7 +129,7 @@ __global__ void ...@@ -125,7 +129,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_d0s_grid, tmp_p_d0_grid,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset, p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
...@@ -138,10 +142,10 @@ __global__ void ...@@ -138,10 +142,10 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -158,7 +162,7 @@ __global__ void ...@@ -158,7 +162,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_d0s_grid, tmp_p_d0_grid,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset, p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
...@@ -171,10 +175,10 @@ __global__ void ...@@ -171,10 +175,10 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -188,7 +192,7 @@ __global__ void ...@@ -188,7 +192,7 @@ __global__ void
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_d0s_grid; ignore = p_d0_grid;
ignore = p_b1_grid; ignore = p_b1_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = p_z_grid; ignore = p_z_grid;
...@@ -200,10 +204,10 @@ __global__ void ...@@ -200,10 +204,10 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
...@@ -263,6 +267,7 @@ template <index_t NumDimG, ...@@ -263,6 +267,7 @@ template <index_t NumDimG,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t DropoutStep,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -317,11 +322,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -317,11 +322,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumD0Tensor = Acc0BiasDataType::Size();
static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination // TODO ANT: implement bias combination
static_assert(NumD1Tensor == 0, "Acc1 Bias addition is unimplemented"); static_assert(std::is_void<Acc1BiasDataType>::value, "Acc1 Bias addition is unimplemented");
using D0DataType = Acc0BiasDataType;
using D1DataType = Acc1BiasDataType;
#if 0 #if 0
// TODO ANT: use alias // TODO ANT: use alias
...@@ -405,40 +409,16 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -405,40 +409,16 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
} }
} }
static auto MakeD0sGridDescriptor_M_N(
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
static auto MakeD0sGridDescriptor_G_M_N(
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
...@@ -462,16 +442,17 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -462,16 +442,17 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch() {}
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n, const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n), d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
...@@ -489,11 +470,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -489,11 +470,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
template <index_t I> __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
{ {
return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0)); return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
...@@ -519,7 +498,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -519,7 +498,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
...@@ -544,7 +523,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -544,7 +523,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
D0sGridDesc_M_N, D0GridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
ZGridDesc_M_N, ZGridDesc_M_N,
...@@ -564,6 +543,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -564,6 +543,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
Gemm1NXdlPerWave, Gemm1NXdlPerWave,
DropoutStep,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -603,41 +583,41 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -603,41 +583,41 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// FIXME: constness // FIXME: constness
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const ADataType* p_a_grid,
const ADataType* p_a_grid, const BDataType* p_b_grid,
const BDataType* p_b_grid, const B1DataType* p_b1_grid,
const B1DataType* p_b1_grid, CDataType* p_c_grid,
CDataType* p_c_grid, ZDataType* p_z_grid,
ZDataType* p_z_grid, LSEDataType* p_lse_grid,
LSEDataType* p_lse_grid, const D0DataType* p_acc0_biases,
const std::array<void*, NumD0Tensor> p_acc0_biases, const D1DataType* p_acc1_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths, const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides, const std::vector<index_t>&
const std::array<std::vector<ck::index_t>, NumD1Tensor> acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<index_t>&
const std::array<std::vector<ck::index_t>, NumD1Tensor> acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op, BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op, float p_dropout,
float p_dropout, std::tuple<unsigned long long, unsigned long long> seeds)
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
...@@ -656,8 +636,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -656,8 +636,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{ b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
...@@ -684,15 +662,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -684,15 +662,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0s_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc1_biases; ignore = p_acc1_biases;
...@@ -709,23 +679,22 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -709,23 +679,22 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N( if constexpr(!is_same<D0DataType, void>::value)
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}; {
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = d0_grid_desc_m_n_ = Transform::MakeCGridDescriptor_M_N(
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
d0s_grid_desc_m_n); d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n_);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
}
} }
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, Acc0BiasDataType>>;
// D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
// for check
d0s_n_length_stride_[i].push_back(
acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
d0s_n_length_stride_[i].push_back(
acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
});
is_dropout_ = p_dropout > 0.0; // is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout; p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
...@@ -735,8 +704,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -735,8 +704,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
z_grid_desc_m_n_);
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]); m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]); n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
...@@ -745,6 +715,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -745,6 +715,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
is_lse_storing_ = false; is_lse_storing_ = false;
} }
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
} }
void Print() const void Print() const
...@@ -755,6 +734,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -755,6 +734,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", " << b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_m_n_: " << d0_grid_desc_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_m_n_.GetLength(I1) << '\n';
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", " << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
...@@ -766,7 +752,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -766,7 +752,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// pointers // pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::D0sGridPointer p_d0s_grid_; const D0DataType* p_d0_grid_;
const B1DataType* p_b1_grid_; const B1DataType* p_b1_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
...@@ -775,6 +761,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -775,6 +761,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
D0GridDesc_M_N d0_grid_desc_m_n_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
...@@ -782,17 +771,16 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -782,17 +771,16 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -830,7 +818,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -830,7 +818,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
index_t n_raw_padded_; index_t n_raw_padded_;
// raw data // raw data
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_n_length_stride_; std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Invoker // Invoker
...@@ -861,7 +849,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -861,7 +849,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2< const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::D0sGridPointer, D0DataType,
CDataType, CDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -873,10 +861,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -873,10 +861,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
...@@ -894,7 +882,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -894,7 +882,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_d0s_grid_, arg.p_d0_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_z_grid_, arg.p_z_grid_,
...@@ -906,10 +894,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -906,10 +894,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
...@@ -1037,18 +1025,19 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1037,18 +1025,19 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return false; return false;
} }
for(int i = 0; i < NumD0Tensor; i++) if constexpr(!is_same<D0DataType, void>::value)
{ {
if(arg.d0s_n_length_stride_[i][1] == 1 && if(arg.d0_n_length_stride_[1] == 1 &&
arg.d0s_n_length_stride_[i][0] % Acc0BiasTransferSrcScalarPerVector != 0) arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
if(arg.d0s_n_length_stride_[i][1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1) if(arg.d0_n_length_stride_[1] != 1 && Acc0BiasTransferSrcScalarPerVector != 1)
{ {
return false; return false;
} }
} }
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
...@@ -1100,39 +1089,39 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1100,39 +1089,39 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const ADataType* p_a, MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
const B1DataType* p_b1, const B1DataType* p_b1,
CDataType* p_c, CDataType* p_c,
ZDataType* p_z, ZDataType* p_z,
LSEDataType* p_lse, LSEDataType* p_lse,
const std::array<void*, NumD0Tensor> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths, const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides, const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor> const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumD1Tensor> const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_dropout, float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -1177,8 +1166,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1177,8 +1166,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
void* p_c, void* p_c,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumD0Tensor> p_acc0_biases, const void* p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases, const void* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1190,11 +1179,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1190,11 +1179,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths, const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides, const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor> const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumD1Tensor> const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1204,36 +1193,37 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1204,36 +1193,37 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
float p_dropout, float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds) override std::tuple<unsigned long long, unsigned long long> seeds) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(
static_cast<const BDataType*>(p_b), static_cast<const ADataType*>(p_a),
static_cast<const B1DataType*>(p_b1), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<const B1DataType*>(p_b1),
static_cast<ZDataType*>(p_z), static_cast<CDataType*>(p_c),
static_cast<LSEDataType*>(p_lse), static_cast<ZDataType*>(p_z),
p_acc0_biases, // cast in struct Argument static_cast<LSEDataType*>(p_lse),
p_acc1_biases, // cast in struct Argument static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
a_gs_ms_ks_lengths, static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
a_gs_ms_ks_strides, a_gs_ms_ks_lengths,
b_gs_ns_ks_lengths, a_gs_ms_ks_strides,
b_gs_ns_ks_strides, b_gs_ns_ks_lengths,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths b_gs_ns_ks_strides,
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
z_gs_ms_ns_lengths, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
z_gs_ms_ns_strides, z_gs_ms_ns_lengths,
lse_gs_ms_lengths, z_gs_ms_ns_strides,
acc0_biases_gs_ms_ns_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_strides, acc1_biases_gs_ms_gemm1ns_lengths,
a_element_op, acc1_biases_gs_ms_gemm1ns_strides,
b_element_op, a_element_op,
acc_element_op, b_element_op,
b1_element_op, acc_element_op,
c_element_op, b1_element_op,
p_dropout, c_element_op,
seeds); p_dropout,
seeds);
} }
// polymorphic // polymorphic
......
...@@ -279,12 +279,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -279,12 +279,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0 #if 0
// TODO ANT: use alias // TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM; static constexpr index_t NumDimGemm0M = NumDimM;
...@@ -603,8 +597,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -603,8 +597,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -619,6 +613,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -619,6 +613,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
ignore = p_acc0_bias_vec;
ignore = p_acc1_bias_vec;
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size(); group_count_ = problem_desc_vec.size();
...@@ -628,11 +625,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -628,11 +625,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
} }
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0; grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
...@@ -710,18 +702,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -710,18 +702,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
...@@ -1055,8 +1035,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1055,8 +1035,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1072,8 +1052,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1072,8 +1052,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c_vec, p_c_vec,
p_z_vec, p_z_vec,
p_lse_vec, p_lse_vec,
p_acc0_biases_vec, p_acc0_bias_vec,
p_acc1_biases_vec, p_acc1_bias_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1094,9 +1074,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1094,9 +1074,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_bias_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_bias_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1111,8 +1091,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -1111,8 +1091,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
p_c_vec, p_c_vec,
p_z_vec, p_z_vec,
p_lse_vec, p_lse_vec,
p_acc0_biases_vec, p_acc0_bias_vec,
p_acc1_biases_vec, p_acc1_bias_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -24,6 +24,7 @@ namespace tensor_operation { ...@@ -24,6 +24,7 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename D0DataType,
typename GemmAccDataType, typename GemmAccDataType,
typename GroupKernelArg, typename GroupKernelArg,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -100,13 +101,17 @@ __global__ void ...@@ -100,13 +101,17 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
typename GridwiseGemm::D0sGridPointer p_d0s_grid = arg_ptr[group_id].p_d0s_grid_; const D0DataType* tmp_p_d0_grid = nullptr;
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>( __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx, In))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
}); tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
}
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < num_blocks_per_batch; i++) for(index_t i = 0; i < num_blocks_per_batch; i++)
...@@ -114,7 +119,7 @@ __global__ void ...@@ -114,7 +119,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
p_d0s_grid, tmp_p_d0_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset, arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr arg_ptr[group_id].p_z_grid_ == nullptr
...@@ -132,10 +137,10 @@ __global__ void ...@@ -132,10 +137,10 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -153,7 +158,7 @@ __global__ void ...@@ -153,7 +158,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
p_d0s_grid, tmp_p_d0_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset, arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
...@@ -170,10 +175,10 @@ __global__ void ...@@ -170,10 +175,10 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -244,6 +249,7 @@ template <index_t NumDimG, ...@@ -244,6 +249,7 @@ template <index_t NumDimG,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t DropoutStep,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -298,11 +304,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -298,11 +304,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumD0Tensor = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
// TODO ANT: implement bias combination // TODO ANT: implement bias combination
static_assert(NumD1Tensor == 0, "Acc1 Bias addition is unimplemented"); static_assert(std::is_void<Acc1BiasDataType>::value, "Acc1 Bias addition is unimplemented");
#if 0 #if 0
// TODO ANT: use alias // TODO ANT: use alias
...@@ -405,44 +410,38 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -405,44 +410,38 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
} }
} }
static auto MakeD0sGridDescriptor_M_N( static auto
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths, MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{ {
return generate_tuple(
[&](auto i) { return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i], acc0_biases_gs_ms_ns_strides);
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
} }
static auto MakeD0sGridDescriptor_G_M_N( static auto
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths, MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{ {
return generate_tuple(
[&](auto i) { return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i], acc0_biases_gs_ms_ns_strides);
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate() constexpr static auto make_MaskOutPredicate()
{ {
...@@ -465,14 +464,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -465,14 +464,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n, const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n), d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
...@@ -490,11 +489,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -490,11 +489,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
template <index_t I> __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
{ {
return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0)); return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
...@@ -520,7 +517,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -520,7 +517,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
...@@ -546,7 +543,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -546,7 +543,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
D0sGridDesc_M_N, D0GridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
ZGridDesc_M_N, ZGridDesc_M_N,
...@@ -566,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -566,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
Gemm1NXdlPerWave, Gemm1NXdlPerWave,
DropoutStep,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -608,7 +606,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -608,7 +606,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// pointers // pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::D0sGridPointer p_d0s_grid_; const D0DataType* p_d0_grid_;
const B1DataType* p_b1_grid_; const B1DataType* p_b1_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
...@@ -617,13 +615,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -617,13 +615,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
...@@ -658,7 +656,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -658,7 +656,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
// raw data // raw data
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_n_length_stride_; std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Argument // Argument
...@@ -671,9 +669,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -671,9 +669,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -693,7 +691,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -693,7 +691,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() && group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() &&
(group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0))) (group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0) &&
(group_count_ == p_z_vec.size() || p_z_vec.size() == 0) &&
(group_count_ == p_lse_vec.size() || p_lse_vec.size() == 0)))
{ {
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
} }
...@@ -704,43 +704,50 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -704,43 +704,50 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]); const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]); const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_d0_grid = (p_acc0_biases_vec.size() == group_count_)
const auto& problem_desc = problem_desc_vec[i]; ? static_cast<const D0DataType*>(p_acc0_biases_vec[i])
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_n_length_stride; : nullptr;
typename GridwiseGemm::D0sGridPointer p_d0s_grid; const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
static_for<0, NumD0Tensor, 1>{}([&](auto j) { const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
using D0DataType = remove_cvref_t<tuple_element_t<j.value, Acc0BiasDataType>>; const auto p_z_grid = (p_z_vec.size() == group_count_)
// D0 pointer ? static_cast<ZDataType*>(p_z_vec[i])
p_d0s_grid(j) = static_cast<const D0DataType*>(p_acc0_biases_vec[i][j]); : nullptr;
// for check const auto p_lse_grid = (p_lse_vec.size() == group_count_)
d0s_n_length_stride[j].push_back( ? static_cast<LSEDataType*>(p_lse_vec[i])
problem_desc.acc0_biases_gs_ms_ns_lengths[j][NumDimG + NumDimM]); : nullptr;
d0s_n_length_stride[j].push_back(
problem_desc.acc0_biases_gs_ms_ns_strides[j][NumDimG + NumDimM]);
});
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
if(p_lse_grid == nullptr) if(p_lse_grid == nullptr)
{ {
is_lse_storing_ = false; is_lse_storing_ = false;
} }
const auto& problem_desc = problem_desc_vec[i];
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const D0sGridDesc_M_N d0s_grid_desc_m_n{
DeviceOp::MakeD0sGridDescriptor_M_N(problem_desc.acc0_biases_gs_ms_ns_lengths, std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
problem_desc.acc0_biases_gs_ms_ns_strides)}; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
const auto d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( if constexpr(!is_same<D0DataType, void>::value)
d0s_grid_desc_m_n); {
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
}
else
{
tmp_d0_gs_ms_ns_lengths = {1, 1, 1, 1};
tmp_d0_gs_ms_ns_strides = {0, 0, 0, 0};
}
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n);
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1( const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N( const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
...@@ -754,9 +761,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -754,9 +761,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K( const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto d0s_grid_desc_g_m_n = DeviceOp::MakeD0sGridDescriptor_G_M_N( const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
problem_desc.acc0_biases_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
problem_desc.acc0_biases_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
...@@ -768,12 +774,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -768,12 +774,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n); c_grid_desc_m_n);
// typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
z_grid_desc_m_n); z_grid_desc_m_n);
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
...@@ -788,7 +790,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -788,7 +790,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k, a_grid_desc_g_m_k,
b_grid_desc_g_n_k, b_grid_desc_g_n_k,
d0s_grid_desc_g_m_n, d0_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
...@@ -800,18 +802,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -800,18 +802,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumD0Tensor and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumD0Tensor &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumD0Tensor &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumD1Tensor &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumD1Tensor))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
const auto raw_m_padded = GridwiseGemm::GetPaddedSize( const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]); problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize( const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
...@@ -819,17 +809,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -819,17 +809,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_d0s_grid, p_d0_grid,
p_b1_grid, p_b1_grid,
p_c_grid, p_c_grid,
p_z_grid, p_z_grid,
p_lse_grid, p_lse_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m_n, z_grid_desc_m_n,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n), block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
...@@ -845,6 +835,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -845,6 +835,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
z_random_matrix_offset = z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count; z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
// for check
std::vector<ck::index_t> d0_n_length_stride;
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1], problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
...@@ -859,10 +854,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -859,10 +854,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n, c_grid_desc_m_n,
d0s_n_length_stride}); d0_n_length_stride});
} }
is_dropout_ = p_dropout > 0.0; // use_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout; p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_; p_dropout_ = 1.f / p_dropout_;
...@@ -889,7 +884,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -889,7 +884,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
unsigned long long seed_; unsigned long long seed_;
unsigned long long offset_; unsigned long long offset_;
GemmAccDataType p_dropout_rescale_; GemmAccDataType p_dropout_rescale_;
bool is_dropout_; bool use_dropout_;
bool is_lse_storing_ = true; bool is_lse_storing_ = true;
}; };
...@@ -925,9 +920,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -925,9 +920,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
float ave_time = 0; float ave_time = 0;
auto launch_kernel = auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) { [&](auto has_main_k_block_loop_, auto use_dropout_, auto is_lse_storing_) {
const auto kernel = const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm, kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
D0DataType,
GemmAccDataType, GemmAccDataType,
GroupKernelArg, GroupKernelArg,
AElementwiseOperation, AElementwiseOperation,
...@@ -936,7 +932,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -936,7 +932,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
has_main_k_block_loop_, has_main_k_block_loop_,
is_dropout_, use_dropout_,
is_lse_storing_, is_lse_storing_,
Deterministic>; Deterministic>;
...@@ -963,7 +959,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -963,7 +959,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
if(arg.is_dropout_) if(arg.use_dropout_)
{ {
if(arg.is_lse_storing_) if(arg.is_lse_storing_)
{ {
...@@ -996,7 +992,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -996,7 +992,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
if(arg.is_dropout_) if(arg.use_dropout_)
{ {
if(arg.is_lse_storing_) if(arg.is_lse_storing_)
{ {
...@@ -1079,19 +1075,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1079,19 +1075,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
return false; return false;
} }
for(int In = 0; In < NumD0Tensor; In++) if constexpr(!is_same<D0DataType, void>::value)
{ {
if(device_arg.d0s_n_length_stride_[In][1] == 1 && if(device_arg.d0_n_length_stride_[1] == 1 &&
device_arg.d0s_n_length_stride_[In][0] % Acc0BiasTransferSrcScalarPerVector != 0) device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
if(device_arg.d0s_n_length_stride_[In][1] != 1 && if(device_arg.d0_n_length_stride_[1] != 1 &&
Acc0BiasTransferSrcScalarPerVector != 1) Acc0BiasTransferSrcScalarPerVector != 1)
{ {
return false; return false;
} }
} }
// Check if having main loop // Check if having main loop
const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -1172,9 +1169,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1172,9 +1169,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
...@@ -1211,9 +1208,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1211,9 +1208,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment