"driver/src/conv_driver.cpp" did not exist on "e72eece8fcf79d1d3a958089fca1f02bfb71b777"
Commit bfa06cf2 authored by fsx950223's avatar fsx950223
Browse files

fix bugs

parent 627016c1
......@@ -25,6 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define VERSION 1
#include <iostream>
#include <numeric>
......@@ -47,8 +48,6 @@ Kernel outputs:
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
#define DIM 32
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -90,15 +89,15 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#if DIM == 32
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
#if VERSION == 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
......@@ -119,8 +118,8 @@ using DeviceGemmInstance =
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
64, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
......@@ -129,8 +128,8 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -153,14 +152,79 @@ using DeviceGemmInstance =
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif DIM == 64
// 2nd template
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
#elif VERSION == 2
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif VERSION == 3
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG,
NumDimM,
NumDimN,
......@@ -226,7 +290,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#else
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG,
NumDimM,
NumDimN,
......@@ -254,8 +318,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
......@@ -263,7 +327,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
1, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -286,8 +351,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#endif
......
......@@ -151,6 +151,7 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
......@@ -182,6 +183,7 @@ template <index_t NumDimG,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -526,9 +528,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype
LSEDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
LSEDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
......@@ -556,6 +559,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment