Commit b07e21a4 authored by mtgu0705

Enable pipeline v1 & v2 for int8 quant

parent 80b46cae
@@ -80,7 +80,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
 //  S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
 //  ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, false, PermuteB>;
+    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, PermuteB>;
 // clang-format on
...
@@ -81,7 +81,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
 //  S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
 //  ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
+    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
 // clang-format on

 template <typename IntType>
...
@@ -3,8 +3,8 @@
 #pragma once

-// #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp"
-// #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"

 namespace ck {
@@ -39,26 +39,79 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
           index_t KPack>
 constexpr auto BlockGemmBScalePipeline_Selector()
 {
-    return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
-                                                   BlockSize,
-                                                   ADataType,
-                                                   BDataType,
-                                                   ComputeDataType,
-                                                   AccDataType,
-                                                   ATileDesc,
-                                                   BTileDesc,
-                                                   AMmaTileDesc,
-                                                   BMmaTileDesc,
-                                                   ABlockTransferSrcScalarPerVector,
-                                                   BBlockTransferSrcScalarPerVector,
-                                                   MPerBlock,
-                                                   NPerBlock,
-                                                   KPerBlock,
-                                                   MPerXDL,
-                                                   NPerXDL,
-                                                   MRepeat,
-                                                   NRepeat,
-                                                   KPack>{};
+    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+    {
+        return BlockwiseGemmXdlops_pipeline_v1_b_scale<BlkGemmPipeSche,
+                                                       BlockSize,
+                                                       ADataType,
+                                                       BDataType,
+                                                       ComputeDataType,
+                                                       AccDataType,
+                                                       ATileDesc,
+                                                       BTileDesc,
+                                                       AMmaTileDesc,
+                                                       BMmaTileDesc,
+                                                       ABlockTransferSrcScalarPerVector,
+                                                       BBlockTransferSrcScalarPerVector,
+                                                       MPerBlock,
+                                                       NPerBlock,
+                                                       KPerBlock,
+                                                       MPerXDL,
+                                                       NPerXDL,
+                                                       MRepeat,
+                                                       NRepeat,
+                                                       KPack>{};
+    }
+    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+    {
+        return BlockwiseGemmXdlops_pipeline_v2_b_scale<BlkGemmPipeSche,
+                                                       BlockSize,
+                                                       ADataType,
+                                                       BDataType,
+                                                       ComputeDataType,
+                                                       AccDataType,
+                                                       ATileDesc,
+                                                       BTileDesc,
+                                                       AMmaTileDesc,
+                                                       BMmaTileDesc,
+                                                       ABlockTransferSrcScalarPerVector,
+                                                       BBlockTransferSrcScalarPerVector,
+                                                       MPerBlock,
+                                                       NPerBlock,
+                                                       KPerBlock,
+                                                       MPerXDL,
+                                                       NPerXDL,
+                                                       MRepeat,
+                                                       NRepeat,
+                                                       KPack>{};
+    }
+    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+    {
+        return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
+                                                       BlockSize,
+                                                       ADataType,
+                                                       BDataType,
+                                                       ComputeDataType,
+                                                       AccDataType,
+                                                       ATileDesc,
+                                                       BTileDesc,
+                                                       AMmaTileDesc,
+                                                       BMmaTileDesc,
+                                                       ABlockTransferSrcScalarPerVector,
+                                                       BBlockTransferSrcScalarPerVector,
+                                                       MPerBlock,
+                                                       NPerBlock,
+                                                       KPerBlock,
+                                                       MPerXDL,
+                                                       NPerXDL,
+                                                       MRepeat,
+                                                       NRepeat,
+                                                       KPack>{};
+    }
+    else
+    {
+        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
+    }
 }

 } // namespace ck
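The selector above now dispatches on the pipeline version with if constexpr, so each branch can return a different pipeline struct type. A minimal standalone sketch of that dispatch pattern (illustrative names only, not CK code):

#include <cstdio>
#include <type_traits>

enum class PipelineVersion { v1, v2, v3 };

struct PipelineV1 { static constexpr const char* name = "v1"; };
struct PipelineV2 { static constexpr const char* name = "v2"; };
struct PipelineV3 { static constexpr const char* name = "v3"; };

template <PipelineVersion Ver>
constexpr auto pipeline_selector()
{
    // Each branch returns a distinct type; the non-taken branches must be
    // discarded at compile time, which is exactly what if constexpr provides.
    if constexpr(Ver == PipelineVersion::v1)
        return PipelineV1{};
    else if constexpr(Ver == PipelineVersion::v2)
        return PipelineV2{};
    else
        return PipelineV3{};
}

int main()
{
    auto pipe = pipeline_selector<PipelineVersion::v2>();
    static_assert(std::is_same_v<decltype(pipe), PipelineV2>);
    std::printf("selected pipeline: %s\n", decltype(pipe)::name);
}

The newly added v1 B-scale pipeline header that the first branch instantiates follows.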
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v1_b_scale
{
};
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::KRepeat;
using Base::xdlops_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
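// Note (added comment, summarizing the intent): with PrefetchStages = 1 the first
// K tile is read from global memory before the main loop, Run() then executes
// (num_loop - 1) main-body iterations, and the last tile is always drained in the
// TailNumber::Full path, so BlockLoopTailNum() can ignore num_loop and report
// TailNumber::Full unconditionally.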
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
// typename AScaleGridBuffer,
// typename AScaleGridDesc,
// typename AScaleThreadDesc,
// typename AScaleThreadTransfer,
// typename AScaleThreadTransferStep,
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadDesc,
typename BScaleThreadTransfer,
typename BScaleThreadTransferStep>
__device__ void Run(
// ABlockCopy
const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
// // AScaleThreadCopy
// const AScaleGridDesc& a_scale_grid_desc,
// const AScaleThreadDesc& a_scale_thread_desc,
// AScaleThreadTransfer& a_scale_thread_copy,
// const AScaleGridBuffer& a_scale_grid_buf,
// const AScaleThreadTransferStep& a_scale_thread_copy_step,
// BScaleThreadCopy
const BScaleGridDesc& b_scale_grid_desc,
const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step,
// num_loop
index_t num_loop,
index_t num_loop_per_scale) const
{
// assume kperblock = scaleblockk
ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
// auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
// a_scale_thread_desc.GetElementSpaceSize());
// auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
// b_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_scale_thread_desc.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// a_scale_thread_copy.Run(a_scale_grid_desc,
// a_scale_grid_buf,
// a_scale_thread_desc,
// make_tuple(I0, I0),
// a_scale_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
});
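// Note (added comment): the static_for above preloads one B scale per N-repeat for
// the first K tile; the At(Number<1>{}) step below then moves the scale window to
// the next ScaleBlockK slice, consistent with the "assume kperblock = scaleblockk"
// assumption stated at the top of Run().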
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
// a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
// b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Initialize C
c_thread_buf.Clear();
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(a_scale_thread_buf[I0]) *
type_convert<AccDataType>(b_scale_thread_buf[n0]);
});
});
});
// a_scale_thread_copy.Run(a_scale_grid_desc,
// a_scale_grid_buf,
// a_scale_thread_desc,
// make_tuple(I0, I0),
// a_scale_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
// a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
// b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(a_scale_thread_buf[I0]) *
type_convert<AccDataType>(b_scale_thread_buf[n0]);
});
});
});
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
} // namespace ck
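In the Run() body above, the MFMA partial products for one (m0, n0) tile are first accumulated unscaled in c_thread_buf_per_scale and then folded into c_thread_buf multiplied by the per-N-block B scale. A scalar reference of that block-scaled accumulation, with made-up sizes and names (not CK code):

#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar reference for B-blockwise-scaled int8 GEMM:
// C[m][n] = sum over K blocks of ( sum_k A[m][k] * B[k][n] ) * b_scale[n_block][k_block]
int main()
{
    constexpr int M = 4, N = 4, K = 8;
    constexpr int ScaleBlockN = 2, ScaleBlockK = 4; // per-block quantization granularity for B

    std::vector<std::int8_t> A(M * K, 1), B(K * N, 2);
    std::vector<float> b_scale((N / ScaleBlockN) * (K / ScaleBlockK), 0.5f);
    std::vector<float> C(M * N, 0.0f);

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k_blk = 0; k_blk < K / ScaleBlockK; ++k_blk)
            {
                std::int32_t partial = 0; // unscaled accumulation, like c_thread_buf_per_scale
                for(int kk = 0; kk < ScaleBlockK; ++kk)
                {
                    const int k = k_blk * ScaleBlockK + kk;
                    partial += static_cast<std::int32_t>(A[m * K + k]) * B[k * N + n];
                }
                // apply the B scale once per (n-block, k-block), mirroring the
                // c_thread_buf += c_thread_buf_per_scale * b_scale_thread_buf[n0] step
                C[m * N + n] += static_cast<float>(partial) *
                                b_scale[(n / ScaleBlockN) * (K / ScaleBlockK) + k_blk];
            }

    std::printf("C[0][0] = %f\n", C[0]); // 1 * 2 * 8 * 0.5 = 8
}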
@@ -330,7 +330,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
         //     a_scale_thread_desc.GetElementSpaceSize());
         // auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
         //     b_scale_thread_desc.GetElementSpaceSize());
-        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ck::half_t>(
+        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_scale_thread_desc.GetElementSpaceSize());

         // Global prefetch 1
...
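The change shown in the v3 pipeline hunk ties the scale scratch buffer's element type to the ComputeDataType template parameter instead of hard-coding ck::half_t. A minimal illustration of that kind of parameterization (assumed names, not CK code):

#include <array>
#include <cstdint>

// Keying a per-thread scratch buffer on a ComputeDataType template parameter
// instead of a fixed half type, so non-fp16 instantiations get a matching buffer.
template <typename ComputeDataType, std::size_t N>
using ScaleScratch = std::array<ComputeDataType, N>;

int main()
{
    ScaleScratch<std::int8_t, 4> int8_scales{}; // int8 quant instantiation
    ScaleScratch<float, 4> float_scales{};      // float instantiation
    return static_cast<int>(int8_scales.size() + float_scales.size()) - 8; // returns 0
}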
@@ -148,6 +148,8 @@ struct DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
         BlkGemmPipelineVer,
         ComputeTypeA,
         ComputeTypeB,
+        PermuteA,
+        PermuteB,
         LDSTypeA,
         LDSTypeB>;
...
@@ -111,6 +111,8 @@ template <typename ALayout,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
           typename ComputeTypeA = CDataType,
           typename ComputeTypeB = ComputeTypeA,
+          bool PermuteA = false,
+          bool PermuteB = false,
           typename LDSTypeA = ADataType,
           typename LDSTypeB = BDataType>
 struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
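PermuteA and PermuteB are appended with defaults of false, so existing instantiations of the gridwise kernel keep compiling unchanged while the device layer above now forwards explicit values. A small standalone sketch of that compatibility pattern (hypothetical names):

#include <type_traits>

// Appending defaulted non-type template parameters keeps old instantiations valid.
template <typename ADataType, typename BDataType,
          bool PermuteA = false, bool PermuteB = false>
struct GridwiseGemmSketch
{
    static constexpr bool permute_b = PermuteB;
};

using OldStyle = GridwiseGemmSketch<float, float>;            // still compiles, PermuteB == false
using NewStyle = GridwiseGemmSketch<float, int, false, true>; // opts into the permuted-B path

static_assert(!OldStyle::permute_b);
static_assert(NewStyle::permute_b);

int main() { return 0; }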
@@ -158,7 +160,7 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     static constexpr index_t APackedSize = []() {
         if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
             return 2;
         else
@@ -340,6 +342,10 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         using GemmSpecialization = tensor_operation::device::GemmSpecialization;

+        static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
+                        GemmSpec != GemmSpecialization::Default),
+                      "pk_i4_t does not support padding");
+
         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                      GemmSpec == GemmSpecialization::MNKPadding)
         {
@@ -394,15 +400,39 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         }
         else
         {
-            // not pad N or K
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
+            if constexpr(!PermuteB)
+            {
+                // not pad N or K
+                const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                    b_grid_desc_nraw_kraw,
+                    make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
+                               make_pass_through_transform(N)),
+                    make_tuple(Sequence<1>{}, Sequence<0>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+                return b_grid_desc_bk0_n_bk1;
+            }
+            else
+            {
+                // Weight Tile Permute
+                constexpr index_t BK01 = KPerBlock / BK1Value;
+                // const index_t BK00 = BK0 / BK01;
+                const index_t BK0_ = StrideB / BK1Value;
+                const index_t BK00 = BK0_ / BK01;
+
+                const auto b_grid_desc_bk00_n_bk01_bk1_permute =
+                    make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
+
+                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
+                    b_grid_desc_bk00_n_bk01_bk1_permute,
+                    make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
+                               make_pass_through_transform(make_tuple(N)),
+                               make_pass_through_transform(BK1Value)),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+                return b_grid_desc_bk0_n_bk1_permute;
+            }
         }
     }
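For the permuted-B path, the descriptor assumes B is stored pre-tiled as [BK00, N, BK01, BK1], where BK01 = KPerBlock / BK1 and BK00 = (StrideB / BK1) / BK01; merging BK00 and BK01 recovers the usual BK0 dimension. A quick check of that index arithmetic with made-up sizes (not taken from the source):

#include <cassert>
#include <cstdio>

int main()
{
    // Hypothetical tile sizes for illustration only.
    const int KPerBlock = 128;
    const int BK1       = 8;   // BK1Value
    const int StrideB   = 512; // plays the role of StrideB in the diff

    const int BK01 = KPerBlock / BK1; // K split inside one weight tile     -> 16
    const int BK0  = StrideB / BK1;   // total BK0 ("BK0_" in the diff)     -> 64
    const int BK00 = BK0 / BK01;      // number of KPerBlock tiles along K  -> 4

    assert(BK00 * BK01 == BK0); // merge_transform(BK00, BK01) reproduces BK0
    std::printf("BK01=%d BK00=%d BK0=%d\n", BK01, BK00, BK0);
}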
@@ -439,6 +469,13 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         }
     }();

+    // // pad M and N
+    // return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
+    //                                    make_tuple(make_right_pad_transform(M, MPad - M),
+    //                                               make_right_pad_transform(N, NPad - N)),
+    //                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+    //                                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+
     using GemmSpecialization = tensor_operation::device::GemmSpecialization;

     if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
@@ -630,20 +667,28 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
     {
         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead;
+            a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
+            a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
         }

         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
+            b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead;
+            if constexpr(!PermuteB)
+            {
+                b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
+            }
+            else
+            {
+                const int k0_offset = karg.KRead * karg.N;
+                b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
+            }
         }

         if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
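The split-K offsets are now divided by APackedSize/BPackedSize because, for packed types such as pk_i4_t, two logical elements share one stored element, so pointer offsets must be expressed in stored units. A small arithmetic sketch with invented numbers:

#include <cstdio>

int main()
{
    // Invented example values (not from the kernel arguments).
    const int KRead       = 256; // K elements each K-batch slice reads
    const int kbatch_id   = 3;   // plays the role of blockIdx.z
    const int BPackedSize = 2;   // pk_i4_t packs two 4-bit values per stored element

    // Column-major B without PermuteB: offset along K, in stored (packed) units.
    const int b_k_split_offset = kbatch_id * KRead / BPackedSize;
    std::printf("b_k_split_offset = %d packed elements\n", b_k_split_offset); // 384
}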
@@ -673,9 +718,9 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         // in some cases.
         else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
         {
-            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
-                                           ? 1
-                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
+            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize < 1
+                                           ? 1
+                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize;
             constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                 make_tuple(
                     AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
@@ -809,10 +854,10 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
         {
             // NLdsLayer * K0 as logical Bank
-            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
-                                           ? 1
-                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeB);
+            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize < 1
+                                           ? 1
+                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize;
             ;
             constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                 make_tuple(
                     BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
@@ -994,8 +1039,8 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-        return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
-                          b_block_space_size_aligned * sizeof(LDSTypeB)),
+        return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
+                          b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
                          c_block_size * sizeof(CShuffleDataType));
     }
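Both the LDS-layering heuristic and the shared-memory size now divide by the packed sizes, so the results stay in bytes of actual storage rather than logical elements. A rough worked example under assumed tile sizes (illustrative only):

#include <algorithm>
#include <cstdio>

int main()
{
    // Assumed, illustrative values.
    const int KPerBlock   = 64;
    const int sizeofLDSA  = 1; // e.g. int8 in LDS
    const int APackedSize = 1; // int8 is not sub-byte packed

    // Mirrors: 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize, clamped to >= 1.
    const int raw       = 32 * 4 / KPerBlock / sizeofLDSA / APackedSize; // 128 / 64 = 2
    const int MLdsLayer = raw < 1 ? 1 : raw;

    // Mirrors GetSharedMemoryNumberOfByte(): element counts scaled to stored bytes,
    // then compared against the C-shuffle scratch requirement.
    const int a_elems = 64 * 128, b_elems = 64 * 128, c_bytes = 32 * 1024;
    const int ab_bytes  = a_elems * sizeofLDSA / APackedSize + b_elems * 1 / 1;
    const int lds_bytes = std::max(ab_bytes, c_bytes);

    std::printf("MLdsLayer=%d lds_bytes=%d\n", MLdsLayer, lds_bytes);
}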
@@ -1009,7 +1054,8 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
         {
             if(!(karg.M % MPerBlock == 0))
             {
@@ -1026,7 +1072,8 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                      (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
         {
             if(!(karg.N % NPerBlock == 0))
             {
@@ -1228,10 +1275,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
             problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);

-        // const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
-        //     make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
-        //                math::integer_divide_ceil(problem.K, ScaleBlockK)),
-        //     make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
         const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
             make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
                        math::integer_divide_ceil(problem.K, ScaleBlockK)),
@@ -1247,10 +1290,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
@@ -1358,7 +1397,7 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<LDSTypeB*>(p_shared) +
-                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
+                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB) / APackedSize,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1377,37 +1416,14 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         const index_t ScaleSliceSizeN = NXdlPerWave;
         const index_t ScaleSliceSizeK = 1;

-        // constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
-        //     make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
-
         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));

-        // auto a_scale_thread_copy =
-        //     ThreadwiseTensorSliceTransfer_v2<AScaleType,
-        //                                      AScaleType,
-        //                                      decltype(a_scale_grid_desc_am_ak),
-        //                                      decltype(a_scale_thread_desc),
-        //                                      Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      false>(
-        //         a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM,
-        //         0));
-
         constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
         auto b_thread_offset =
             get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl;

-        // __syncthreads();
-        // if(blockIdx.x==0)
-        // {
-        //     printf("ThreadIdx: %d, b_thread_offset: %d\n", get_thread_local_1d_id(), b_thread_offset);
-        // }
-
         auto b_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<BScaleType,
                                              BScaleType,
@@ -1422,7 +1438,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
                 b_scale_grid_desc_bn_ak,
                 make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));

-        // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
         constexpr auto b_scale_thread_slice_copy_step =
             make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1));
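The scale copy walks N first and then K: the first component of b_scale_thread_slice_copy_step advances by NWaves * NPerXdl per N-repeat, and the second rewinds N by NPerBlock and moves one scale block along K, matching how Run() applies step At(0) inside the NRepeat loop and step At(1) after it. A tiny trace of that window movement under assumed sizes (not taken from the source):

#include <cstdio>

int main()
{
    // Assumed, illustrative tile configuration.
    const int NPerXdl     = 32;
    const int NWaves      = 2;
    const int NXdlPerWave = 4; // ScaleSliceSizeN in the diff equals NXdlPerWave
    const int NPerBlock   = NWaves * NPerXdl * NXdlPerWave; // 256

    int n = 0, k = 0;
    for(int n0 = 0; n0 < NXdlPerWave; ++n0)
    {
        std::printf("load scale at (n=%d, k=%d)\n", n, k);
        n += NWaves * NPerXdl; // step At(0): scale for the next N-repeat
    }
    n += -NPerBlock;           // step At(1): rewind N ...
    k += 1;                    // ... and advance one scale block in K
    std::printf("window back at (n=%d, k=%d) for the next K tile\n", n, k);
}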
@@ -1443,12 +1458,6 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
                 b_block_slice_copy_step,

                 c_thread_buf,

-                // a_scale_grid_desc_am_ak,
-                // a_scale_thread_desc,
-                // a_scale_thread_copy,
-                // a_scale_grid_buf,
-                // a_scale_thread_slice_copy_step,
-
                 b_scale_grid_desc_bn_ak,
                 b_scale_thread_desc,
                 b_scale_thread_copy,
...