Commit bca3f14c authored by coderfeli's avatar coderfeli
Browse files

fix nswizzle=0

parent e78fbf87
...@@ -139,6 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256; ...@@ -139,6 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = true;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
...@@ -174,7 +175,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm ...@@ -174,7 +175,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>, MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......
...@@ -169,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm ...@@ -169,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>, CShuffleMXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......
...@@ -66,6 +66,7 @@ template <typename ALayout, ...@@ -66,6 +66,7 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors, typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
bool IsInputGemm = true, bool IsInputGemm = true,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
...@@ -133,6 +134,7 @@ struct DeviceMoeGemm ...@@ -133,6 +134,7 @@ struct DeviceMoeGemm
CDEShuffleBlockTransferScalarPerVectors, CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched, BlkGemmPipeSched,
BlkGemmPipelineVer, BlkGemmPipelineVer,
NSwizzle,
ComputeTypeA, ComputeTypeA,
ComputeTypeB, ComputeTypeB,
LDSTypeA, LDSTypeA,
......
...@@ -142,6 +142,7 @@ template <typename ALayout, ...@@ -142,6 +142,7 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors, typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType, typename LDSTypeA = ADataType,
...@@ -197,9 +198,11 @@ struct GridwiseMoeGemm ...@@ -197,9 +198,11 @@ struct GridwiseMoeGemm
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
{ {
return std::make_tuple(math::integer_divide_ceil(N, NPerBlock) * math::integer_divide_ceil(M, MPerBlock), const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
1, const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
1); const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;
return std::make_tuple(gridx, gridy, 1);
} }
__host__ __device__ static auto CalculateMPadded(index_t M) __host__ __device__ static auto CalculateMPadded(index_t M)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment