Commit bfa06cf2 authored by fsx950223's avatar fsx950223
Browse files

fix bugs

parent 627016c1
...@@ -25,6 +25,7 @@ Kernel outputs: ...@@ -25,6 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define VERSION 1
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -47,8 +48,6 @@ Kernel outputs: ...@@ -47,8 +48,6 @@ Kernel outputs:
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
#define DIM 32
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -90,15 +89,15 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -90,15 +89,15 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#if DIM == 32 #if VERSION == 1
using DeviceGemmInstance = using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -119,8 +118,8 @@ using DeviceGemmInstance = ...@@ -119,8 +118,8 @@ using DeviceGemmInstance =
256, 256,
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 64, // KPerBlock
32, // Gemm1NPerBlock 128, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -129,8 +128,8 @@ using DeviceGemmInstance = ...@@ -129,8 +128,8 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
1, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave 2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -153,14 +152,79 @@ using DeviceGemmInstance = ...@@ -153,14 +152,79 @@ using DeviceGemmInstance =
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 4, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif DIM == 64 #elif VERSION == 2
// 2nd template using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
using DeviceGemmInstance = NumDimG,
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1< NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif VERSION == 3
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -226,7 +290,7 @@ using DeviceGemmInstance = ...@@ -226,7 +290,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#else #else
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -254,8 +318,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -254,8 +318,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 32, // Gemm1NPerBlock
64, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1 2, // B1K1
...@@ -263,7 +327,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -263,7 +327,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 1, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -286,8 +351,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -286,8 +351,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#endif #endif
......
...@@ -151,6 +151,7 @@ template <index_t NumDimG, ...@@ -151,6 +151,7 @@ template <index_t NumDimG,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename DataType,
typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
...@@ -182,6 +183,7 @@ template <index_t NumDimG, ...@@ -182,6 +183,7 @@ template <index_t NumDimG,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -526,9 +528,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -526,9 +528,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
LSEDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
LSEDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
...@@ -556,6 +559,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -556,6 +559,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
Gemm1NXdlPerWave, Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment