"configs/vscode:/vscode.git/clone" did not exist on "3cfe73de3ff75cdb657f0db77cd6228c28f41da1"
Commit 883c060a authored by guangzlu

bwd qloop v1 passed

parent 72498367
The hunks shown in all three files are a whitespace-only reflow; each is reproduced below in its post-commit form.

@@ -18,15 +18,36 @@ namespace device {
namespace instance {
void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2,
        1,
        1,
        1,
        1,
        F16,
        F16,
        unsigned short,
        F32,
        ck::Tuple<>,
        ck::Tuple<>,
        PassThrough,
        PassThrough,
        Scale,
        PassThrough,
        PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedMultiheadAttentionBackward<2,
                                                1,
                                                1,
                                                1,
                                                1,
                                                F16,
                                                F16,
                                                unsigned short,
                                                F32,
                                                ck::Tuple<>,
                                                ck::Tuple<>,
                                                PassThrough,
                                                PassThrough,
                                                Scale,
                                                PassThrough,
                                                PassThrough,
                                                MaskingSpecialization::MaskDisabled>>>& instances);
void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2,
        1,
        1,
        1,
        1,
        BF16,
        BF16,
        unsigned short,
        F32,
        ck::Tuple<>,
        ck::Tuple<>,
        PassThrough,
        PassThrough,
        Scale,
        PassThrough,
        PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedMultiheadAttentionBackward<2,
                                                1,
                                                1,
                                                1,
                                                1,
                                                BF16,
                                                BF16,
                                                unsigned short,
                                                F32,
                                                ck::Tuple<>,
                                                ck::Tuple<>,
                                                PassThrough,
                                                PassThrough,
                                                Scale,
                                                PassThrough,
                                                PassThrough,
                                                MaskingSpecialization::MaskDisabled>>>& instances);
template <typename InputDataType,
          typename OutputDataType,
@@ -151,8 +148,7 @@ struct DeviceOperationInstanceFactory<
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
        {
            add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
        }
        else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
...
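For orientation, a minimal driver for these registration hooks might look as follows. This is a hypothetical sketch, not part of the commit: the CK include paths are omitted, and the F16/F32/PassThrough/Scale aliases are assumed to resolve to ck::half_t, float, and the usual ck::tensor_operation::element_wise functors, matching the aliases visible in this diff. GetTypeString() is assumed to come from CK's base operator interface.

// Hypothetical driver (sketch only; assumes the CK headers declaring the
// types and adder functions above have been included).
#include <iostream>
#include <memory>
#include <vector>

using F16         = ck::half_t;
using F32         = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;
using ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward;
using ck::tensor_operation::device::MaskingSpecialization;

int main()
{
    // Element type mirrors the causal ("casual") f16 declaration above.
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2, 1, 1, 1, 1, F16, F16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>
        instances;

    // Populate with every registered qloop backward kernel for this shape.
    ck::tensor_operation::device::instance::
        add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(instances);

    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n'; // one kernel name per instance
}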
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
@@ -50,9 +51,8 @@ template <index_t NumDimG,
          index_t NumDimK,
          index_t NumDimO,
          MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
    // clang-format off
    // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
    // ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
    // ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
    // clang-format on
    >;
void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2,
        1,
        1,
        1,
        1,
        BF16,
        BF16,
        unsigned short,
        F32,
        ck::Tuple<>,
        ck::Tuple<>,
        PassThrough,
        PassThrough,
        Scale,
        PassThrough,
        PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
                                       2,
                                       1,
                                       1,
                                       1,
                                       1,
                                       MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
}
void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedMultiheadAttentionBackward<2,
                                                1,
                                                1,
                                                1,
                                                1,
                                                BF16,
                                                BF16,
                                                unsigned short,
                                                F32,
                                                ck::Tuple<>,
                                                ck::Tuple<>,
                                                PassThrough,
                                                PassThrough,
                                                Scale,
                                                PassThrough,
                                                PassThrough,
                                                MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
                                       2,
                                       1,
                                       1,
                                       1,
                                       1,
                                       MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
...
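The add_device_operation_instances(instances, ...{}) calls above take a std::tuple of concrete kernel types and append one heap-allocated instance per tuple element. A minimal sketch of that registration pattern, under the assumption that it behaves like the generic helper below (this is not CK's actual implementation, which lives elsewhere in the library):

#include <memory>
#include <tuple>
#include <vector>

// Sketch of the registration helper: for every concrete operation type Op in
// the tuple, push a default-constructed Op, owned through the abstract Base
// interface, onto the instance vector.
template <typename Base, typename... Ops>
void add_device_operation_instances(std::vector<std::unique_ptr<Base>>& instances,
                                    const std::tuple<Ops...>&)
{
    (instances.push_back(std::make_unique<Ops>()), ...);
}

In the functions above, Base is the DeviceBatchedMultiheadAttentionBackward<...> interface and each Op is one DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 specialization from the tuple alias.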
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
@@ -50,9 +51,8 @@ template <index_t NumDimG,
          index_t NumDimK,
          index_t NumDimO,
          MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
    // clang-format off
    // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
    // ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
    // ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
    // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
    // clang-format on
    >;
void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
        2,
        1,
        1,
        1,
        1,
        F16,
        F16,
        unsigned short,
        F32,
        ck::Tuple<>,
        ck::Tuple<>,
        PassThrough,
        PassThrough,
        Scale,
        PassThrough,
        PassThrough,
        MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
                                       2,
                                       1,
                                       1,
                                       1,
                                       1,
                                       MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
}
void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedMultiheadAttentionBackward<2,
                                                1,
                                                1,
                                                1,
                                                1,
                                                F16,
                                                F16,
                                                unsigned short,
                                                F32,
                                                ck::Tuple<>,
                                                ck::Tuple<>,
                                                PassThrough,
                                                PassThrough,
                                                Scale,
                                                PassThrough,
                                                PassThrough,
                                                MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
                                       2,
                                       1,
                                       1,
                                       1,
                                       1,
                                       MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
...
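Downstream code normally reaches these instances through the DeviceOperationInstanceFactory specialization updated in the first file, whose if constexpr chain routes MaskUpperTriangleFromTopLeft to the casual adders and MaskDisabled to the noncasual ones. A hedged consumer sketch follows; GetInstances() is assumed to be CK's usual factory entry point, and everything here is surrounding context, not part of this commit:

// Hypothetical consumer; F16/F32/PassThrough/Scale aliases as in the sketch
// after the first file.
using DeviceOp = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward<
    2, 1, 1, 1, 1, F16, F16, unsigned short, F32, ck::Tuple<>, ck::Tuple<>,
    PassThrough, PassThrough, Scale, PassThrough, PassThrough,
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled>;

// MaskDisabled selects the noncasual f16 instance list through the factory's
// compile-time dispatch shown in the first hunk of this commit.
auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();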