Commit 03c2ba3a authored by aska-0096

bug fix + performance opt + clangformat

parent 1a324dfb
...@@ -92,29 +92,6 @@ constexpr auto BlockGemmPipeline_Selector()
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche, return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -135,9 +112,9 @@ constexpr auto BlockGemmPipeline_Selector()
NRepeat, NRepeat,
KPack>{}; KPack>{};
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche, return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -158,6 +135,29 @@ constexpr auto BlockGemmPipeline_Selector()
NRepeat, NRepeat,
KPack>{}; KPack>{};
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
{
return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else else
{ {
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
......
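For context on the selector hunks above: `BlockGemmPipeline_Selector()` picks a blockwise-GEMM pipeline implementation at compile time by testing `BlkGemmPipelineVer` with `if constexpr`, so each instantiation returns a different pipeline type (v3 and v4 here being the B-scale variants). A minimal standalone sketch of that dispatch pattern, with hypothetical simplified types in place of the real template parameters:

    #include <iostream>

    enum class PipelineVersion { v1, v2, v3, v4, v5 };

    // Hypothetical stand-ins for the real pipeline structs.
    struct PipelineV3 { static constexpr const char* name = "v3_b_scale"; };
    struct PipelineV4 { static constexpr const char* name = "v4_b_scale"; };
    struct PipelineV5 { static constexpr const char* name = "v5"; };

    template <PipelineVersion Ver>
    constexpr auto pipeline_selector()
    {
        // Branches that do not match Ver are discarded at compile time,
        // so each instantiation can return a different type.
        if constexpr(Ver == PipelineVersion::v3)
        {
            return PipelineV3{};
        }
        else if constexpr(Ver == PipelineVersion::v4)
        {
            return PipelineV4{};
        }
        else if constexpr(Ver == PipelineVersion::v5)
        {
            return PipelineV5{};
        }
    }

    int main()
    {
        auto pipe = pipeline_selector<PipelineVersion::v4>();
        std::cout << decltype(pipe)::name << '\n'; // prints "v4_b_scale"
    }

Because discarded branches are never instantiated, the return type may differ per enum value, which is what lets the selector hand back an entirely different pipeline struct for each version.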
...@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now... // ,bool TransposeC //disable transposec right now...
> >
struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intrawave, struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intrawave,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
ATileDesc, ATileDesc,
BTileDesc, BTileDesc,
AMmaTileDesc, AMmaTileDesc,
BMmaTileDesc, BMmaTileDesc,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize, : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -170,7 +170,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
typename BBlockBuffer, typename BBlockBuffer,
typename BBlockTransferStep, typename BBlockTransferStep,
typename CThreadBuffer, typename CThreadBuffer,
//BScale Thread Copy // BScale Thread Copy
typename BScaleGridBuffer, typename BScaleGridBuffer,
typename BScaleGridDesc, typename BScaleGridDesc,
typename BScaleThreadDesc, typename BScaleThreadDesc,
...@@ -209,7 +209,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_scale_thread_desc.GetElementSpaceSize()); b_scale_thread_desc.GetElementSpaceSize());
......
...@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now... // ,bool TransposeC //disable transposec right now...
> >
struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intrawave, struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intrawave,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
ATileDesc, ATileDesc,
BTileDesc, BTileDesc,
AMmaTileDesc, AMmaTileDesc,
BMmaTileDesc, BMmaTileDesc,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize, : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -546,25 +546,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now... // ,bool TransposeC //disable transposec right now...
> >
struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Interwave, struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Interwave,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
ATileDesc, ATileDesc,
BTileDesc, BTileDesc,
AMmaTileDesc, AMmaTileDesc,
BMmaTileDesc, BMmaTileDesc,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize, : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -719,16 +719,16 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
const BScaleGridDesc& b_scale_grid_desc, const BScaleGridDesc& b_scale_grid_desc,
//BScaleThreadCopy // BScaleThreadCopy
const BScaleThreadDesc& b_scale_thread_desc, const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy, BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf, const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step, const BScaleThreadTransferStep& b_scale_thread_copy_step,
//num loop // num loop
index_t num_loop, index_t num_loop,
index_t num_loop_per_scale) const index_t num_loop_per_scale) const
{ {
ignore = num_loop_per_scale; ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
...@@ -751,7 +751,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
b_scale_thread_desc, b_scale_thread_desc,
make_tuple(n0, I0), make_tuple(n0, I0),
b_scale_thread_buf); b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{})); b_scale_thread_copy_step.At(Number<0>{}));
}); });
...@@ -864,16 +864,16 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
} }
}); });
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
// {
// constexpr index_t c_offset = // constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) += // c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] * // c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]); // type_convert<AccDataType>(b_scale_thread_buf[n0]);
// }); // });
}); });
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -983,10 +983,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
// constexpr index_t c_offset = // constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) += // c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] * // c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]); // type_convert<AccDataType>(b_scale_thread_buf[n0]);
// }); // });
}); });
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -1084,7 +1083,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
// c_thread_buf_per_scale[Number<t>{}] * // c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]); // type_convert<AccDataType>(b_scale_thread_buf[n0]);
// }); // });
}); });
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
......
...@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now... // ,bool TransposeC //disable transposec right now...
> >
struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intrawave, struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intrawave,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
ATileDesc, ATileDesc,
BTileDesc, BTileDesc,
AMmaTileDesc, AMmaTileDesc,
BMmaTileDesc, BMmaTileDesc,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize, : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -295,25 +295,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
BBlockBuffer& b_block_buf, BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
//BScaleThreadCopy // BScaleThreadCopy
const BScaleGridDesc& b_scale_grid_desc, const BScaleGridDesc& b_scale_grid_desc,
const BScaleThreadDesc& b_scale_thread_desc, const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy, BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf, const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step, const BScaleThreadTransferStep& b_scale_thread_copy_step,
//num loop // num loop
index_t num_loop, index_t num_loop,
index_t num_loop_per_scale) const index_t num_loop_per_scale) const
{ {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
//B scale buffer // B scale buffer
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_scale_thread_desc.GetElementSpaceSize()); b_scale_thread_desc.GetElementSpaceSize());
...@@ -328,14 +327,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
b_scale_thread_desc, b_scale_thread_desc,
make_tuple(n0, I0, I0), make_tuple(n0, I0),
b_scale_thread_buf); b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{})); b_scale_thread_copy_step.At(Number<0>{}));
}); });
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{})); if(num_loop_per_scale == 1)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
}
// Local prefill 1 // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
...@@ -350,7 +358,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>(); // need actually?
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
1,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_per_scale;
// Local prefetch 1 // Local prefetch 1
block_sync_lds(); block_sync_lds();
...@@ -415,10 +429,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// constexpr index_t c_offset = // constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run( xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
a_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), c_thread_buf_per_scale.GetVectorTypeReference(I0));
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
...@@ -455,15 +468,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
b_scale_thread_desc, b_scale_thread_desc,
make_tuple(n0, I0, I0), make_tuple(n0, I0),
b_scale_thread_buf); b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow( b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
}); });
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, if((i + 2) % num_loop_per_scale == 0)
b_scale_thread_copy_step.At(Number<1>{})); {
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
}
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
......
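The functional change in the v3 B-scale pipeline above is the scale-window advance: instead of always applying copy-step 1 after the `NRepeat` scale reads, the code now applies step 2 (rewind N and move one scale block forward in K) only when the loop crosses a `ScaleBlockK` boundary, i.e. every `num_loop_per_scale` iterations (the `i + 2` offset accounts for the pipeline's prefetch distance), and step 1 (rewind N only) otherwise. A rough host-side sketch of that indexing logic, with simplified unit steps standing in for the real multi-index steps:

    #include <cassert>

    struct Window { int n = 0; int k_scale = 0; };

    // Hypothetical stand-ins for b_scale_thread_copy_step.At(Number<0/1/2>{}):
    // step 0 walks the NRepeat sub-tiles, step 1 rewinds N only,
    // step 2 rewinds N and advances to the next K scale block.
    void run_scale_copy(int num_loop, int num_loop_per_scale, int NRepeat)
    {
        Window w;
        for(int i = 0; i < num_loop; ++i)
        {
            for(int n0 = 0; n0 < NRepeat; ++n0)
            {
                // (read one scale for sub-tile n0 at position w) ...
                w.n += 1;                       // step 0
            }
            if((i + 1) % num_loop_per_scale == 0)
            {
                w.n -= NRepeat; w.k_scale += 1; // step 2: next scale block in K
            }
            else
            {
                w.n -= NRepeat;                 // step 1: same scale block
            }
        }
        // With num_loop a multiple of num_loop_per_scale, the window has
        // advanced exactly once per ScaleBlockK worth of K iterations.
        assert(w.k_scale == num_loop / num_loop_per_scale);
        assert(w.n == 0);
    }

    int main()
    {
        run_scale_copy(/*num_loop=*/8, /*num_loop_per_scale=*/2, /*NRepeat=*/4);
    }

With `num_loop_per_scale == 1` every iteration takes the step-2 path, which matches the special case added before the prologue prefetch.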
...@@ -268,13 +268,13 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
BBlockBuffer& b_block_buf, BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
//BScaleThreadCopy // BScaleThreadCopy
const BScaleGridDesc& b_scale_grid_desc, const BScaleGridDesc& b_scale_grid_desc,
const BScaleThreadDesc& b_scale_thread_desc, const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy, BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf, const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step, const BScaleThreadTransferStep& b_scale_thread_copy_step,
//num loop // num loop
index_t num_loop, index_t num_loop,
index_t num_loop_per_scale) const index_t num_loop_per_scale) const
{ {
...@@ -284,7 +284,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
//B scale buffer // B scale buffer
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_scale_thread_desc.GetElementSpaceSize()); b_scale_thread_desc.GetElementSpaceSize());
...@@ -409,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
make_tuple(n0, I0), make_tuple(n0, I0),
b_scale_thread_bufs(lds_read_reg_buf)); b_scale_thread_bufs(lds_read_reg_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_thread_copy_step.At(Number<0>{})); b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
}); });
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_thread_copy_step.At(Number<1>{})); b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy.RunWrite( a_blockwise_copy.RunWrite(
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
...@@ -426,7 +426,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear(); c_thread_buf_per_scale.Clear();
...@@ -437,32 +436,32 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg_buf] a_thread_bufs[mfma_reg_buf]
[Number<a_thread_desc_.CalculateOffset( [Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf] b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset( [Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
// constexpr index_t c_offset = // constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run( xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
a_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), c_thread_buf_per_scale.GetVectorTypeReference(I0));
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]); type_convert<AccDataType>(
b_scale_thread_bufs(mfma_reg_buf)[n0]);
}); });
}); });
}); });
...@@ -513,15 +512,14 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_bufs(lds_read_reg_buf)); b_scale_thread_bufs(lds_read_reg_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{})); b_scale_thread_copy_step.At(Number<0>{}));
}); });
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{})); b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear(); c_thread_buf_per_scale.Clear();
...@@ -595,10 +593,10 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_bufs(lds_read_reg_buf)); b_scale_thread_bufs(lds_read_reg_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{})); b_scale_thread_copy_step.At(Number<0>{}));
}); });
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{})); b_scale_thread_copy_step.At(Number<1>{}));
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
...@@ -640,7 +638,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
}; };
auto CompFunc = [&](auto mfma_reg_buf) { auto CompFunc = [&](auto mfma_reg_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear(); c_thread_buf_per_scale.Clear();
......
...@@ -99,7 +99,7 @@ struct DeviceGemmV2BScale : public BaseOperator
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
const void* p_b_scale, const void* p_b_scale,
ck::index_t KSplit, ck::index_t KSplit,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
......
...@@ -35,8 +35,8 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t BlockSize, index_t BlockSize,
index_t ScaleBlockN, // scale block for N index_t ScaleBlockN, // scale block for N
index_t ScaleBlockK, // scale block for K index_t ScaleBlockK, // scale block for K
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -218,7 +218,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
}; };
constexpr index_t minimum_occupancy = constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
? 1
: 2
: 2;
if(has_main_k_block_loop) if(has_main_k_block_loop)
{ {
...@@ -659,12 +664,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
const BScaleDataType* p_b_scale, const BScaleDataType* p_b_scale,
index_t KBatch, index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, p_b_scale, KBatch, a_element_op, b_element_op, c_element_op}; return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
p_b_scale,
KBatch,
a_element_op,
b_element_op,
c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -680,7 +698,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
const void* p_b_scale, const void* p_b_scale,
index_t KBatch, index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
......
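In the launcher hunk above, `minimum_occupancy` is no longer a flat scheduler-based choice: the nested ternary requests occupancy 1 only for an Intrawave v3 pipeline whose `MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType)` product exceeds that of a 128x128x64 fp16 tile, and 2 for every other configuration. The same heuristic written as a small constexpr helper (hypothetical names, illustration only):

    #include <cstddef>

    enum class Scheduler { Intrawave, Interwave };
    enum class PipelineVersion { v1, v2, v3, v4, v5 };

    // Mirrors the nested ternary from the diff above.
    constexpr int minimum_occupancy(Scheduler sched,
                                    PipelineVersion ver,
                                    int m_per_block,
                                    int n_per_block,
                                    int k_per_block,
                                    std::size_t a_elem_bytes)
    {
        const bool big_intrawave_v3_tile =
            sched == Scheduler::Intrawave && ver == PipelineVersion::v3 &&
            static_cast<std::size_t>(m_per_block) * n_per_block * k_per_block * a_elem_bytes >
                128u * 128u * 64u * 2u;
        // Large Intrawave v3 tiles use more registers/LDS, so request occupancy 1;
        // everything else keeps 2.
        return big_intrawave_v3_tile ? 1 : 2;
    }

    static_assert(minimum_occupancy(Scheduler::Intrawave, PipelineVersion::v3, 256, 256, 64, 2) == 1, "");
    static_assert(minimum_occupancy(Scheduler::Intrawave, PipelineVersion::v3, 128, 128, 64, 2) == 2, "");
    static_assert(minimum_occupancy(Scheduler::Interwave, PipelineVersion::v3, 256, 256, 64, 2) == 2, "");

    int main() { return 0; }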
...@@ -13,9 +13,9 @@ namespace ck {
__host__ __device__ inline half4_t pki4_to_half4(int q) __host__ __device__ inline half4_t pki4_to_half4(int q)
{ {
const int LO = 0x000f000f; constexpr int LO = 0x000f000f;
const int HI = 0x00f000f0; constexpr int HI = 0x00f000f0;
const int EX = 0x64006400; constexpr int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s. // Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
...@@ -23,9 +23,9 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
int hi = amd_assembly_and_or_b32(q, HI, EX); int hi = amd_assembly_and_or_b32(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`. // directly into `SUB` and `ADD`.
const int SUB = 0xE408E408; //-8 constexpr int SUB = 0xE408E408; //-8
const int MUL = 0x2c002c00; // 1/16 constexpr int MUL = 0x2c002c00; // 1/16
const int ADD = 0xd480d480; //-79 constexpr int ADD = 0xd480d480; //-79
vector_type<half_t, 4> res; vector_type<half_t, 4> res;
...@@ -34,7 +34,15 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16( res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD)); bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
#if 0
asm volatile("v_and_or_b32 %0, %4, %5, %7 \n \
v_and_or_b32 %1, %4, %6, %7 \n \
v_pk_add_f16 %2, %0, %8 \n \
v_pk_fma_f16 %3, %1, %9, %10 \
"
: "=v"(lo), "=v"(hi), "=v"(res.template AsType<half2_t>()(Number<0>{})), "=v"(res.template AsType<half2_t>()(Number<1>{}))
: "v"(q), "v"(LO), "v"(HI), "s"(EX), "s"(SUB), "v"(MUL), "s"(ADD), "0"(lo), "1"(hi));
#endif
return res.template AsType<half4_t>()[Number<0>{}]; return res.template AsType<half4_t>()[Number<0>{}];
} }
...@@ -80,14 +88,14 @@ struct PassThroughPack8
{ {
#if 1 #if 1
int x_permute = 0; int x_permute = 0;
int bits4_0 = (bit_cast<int>(x) >> 0) & 0xF; int bits4_0 = (bit_cast<int>(x) >> 0) & 0xF;
int bits4_1 = (bit_cast<int>(x) >> 4) & 0xF; int bits4_1 = (bit_cast<int>(x) >> 4) & 0xF;
int bits4_2 = (bit_cast<int>(x) >> 8) & 0xF; int bits4_2 = (bit_cast<int>(x) >> 8) & 0xF;
int bits4_3 = (bit_cast<int>(x) >> 12) & 0xF; int bits4_3 = (bit_cast<int>(x) >> 12) & 0xF;
int bits4_4 = (bit_cast<int>(x) >> 16) & 0xF; int bits4_4 = (bit_cast<int>(x) >> 16) & 0xF;
int bits4_5 = (bit_cast<int>(x) >> 20) & 0xF; int bits4_5 = (bit_cast<int>(x) >> 20) & 0xF;
int bits4_6 = (bit_cast<int>(x) >> 24) & 0xF; int bits4_6 = (bit_cast<int>(x) >> 24) & 0xF;
int bits4_7 = (bit_cast<int>(x) >> 28) & 0xF; int bits4_7 = (bit_cast<int>(x) >> 28) & 0xF;
x_permute |= (bits4_1 << 0); x_permute |= (bits4_1 << 0);
x_permute |= (bits4_3 << 4); x_permute |= (bits4_3 << 4);
...@@ -111,7 +119,7 @@ struct PassThroughPack8
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x)); result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8); result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
y = result.template AsType<half8_t>()[Number<0>{}]; y = result.template AsType<half8_t>()[Number<0>{}];
#else #else
vector_type<half_t, 8> dst; vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x}; vector_type<pk_i4_t, 4> src{x};
...@@ -125,7 +133,7 @@ struct PassThroughPack8
dst.template AsType<half2_t>()(Number<3>{}) = dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]); pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}]; y = dst.template AsType<half8_t>()[Number<0>{}];
#endif #endif
} }
......
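The `pki4_to_half4` hunks above promote the magic constants to `constexpr` and stash an experimental inline-asm variant behind `#if 0`; the dequantization trick itself is unchanged: OR-ing a 4-bit value into the mantissa of the fp16 bit pattern 0x6400 yields the half value 1024 + x, so adding the constant 0xE408 (-1032) recovers the signed value x - 8 without any integer-to-float conversion instructions. A small host-side check of that identity, using a hand-rolled fp16 decoder (illustrative only, normal numbers; not the library's code):

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Decode an IEEE fp16 bit pattern to float (normal numbers only; enough
    // for the constants used by pki4_to_half4).
    float half_bits_to_float(std::uint16_t h)
    {
        const int sign = (h >> 15) & 0x1;
        const int exp  = (h >> 10) & 0x1f;
        const int man  = h & 0x3ff;
        const float mag = std::ldexp(1.0f + man / 1024.0f, exp - 15);
        return sign ? -mag : mag;
    }

    int main()
    {
        const std::uint16_t EX  = 0x6400; // 1024.0 with an all-zero mantissa
        const std::uint16_t SUB = 0xE408; // -1032.0
        for(int x = 0; x < 16; ++x)
        {
            // Low-nibble path: (q & 0xf) | EX is the half value 1024 + x,
            // and adding -1032 gives the signed int4 value x - 8.
            const float lo = half_bits_to_float(static_cast<std::uint16_t>(EX | x));
            assert(lo == 1024.0f + x);
            assert(lo + half_bits_to_float(SUB) == static_cast<float>(x - 8));
        }
        return 0;
    }

The high-nibble path works the same way except the OR lands the nibble four bits higher, so the result is scaled back with the packed multiply by 1/16 (0x2c00) before the zero-point add.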
...@@ -48,12 +48,7 @@ __global__ void
// karg); // karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, karg.p_b_scale_grid, p_shared, karg);
karg.p_b_grid,
karg.p_c_grid,
karg.p_b_scale_grid,
p_shared,
karg);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx9__)) #endif // end of if (defined(__gfx9__))
...@@ -113,8 +108,8 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec, tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize, index_t BlockSize,
index_t ScaleBlockN, // scale N index_t ScaleBlockN, // scale N
index_t ScaleBlockK, // scale K index_t ScaleBlockK, // scale K
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -605,7 +600,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_, index_t StrideC_,
const BScaleType* p_b_scale_grid_, const BScaleType* p_b_scale_grid_,
index_t k_batch_, index_t k_batch_,
AElementwiseOperation a_element_op_, AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
...@@ -636,7 +631,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
const ADataType* p_a_grid; const ADataType* p_a_grid;
const BDataType* p_b_grid; const BDataType* p_b_grid;
CDataType* p_c_grid; CDataType* p_c_grid;
const BScaleType* p_b_scale_grid; const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op; const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op; const BElementwiseOperation b_element_op;
...@@ -1303,13 +1298,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B Scale grid and buffer // B Scale grid and buffer
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK), math::integer_divide_ceil(problem.K, ScaleBlockK)),
ScaleBlockK), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK),
1,
0));
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const AElementwiseOperation a_element_op{}; const AElementwiseOperation a_element_op{};
...@@ -1438,12 +1430,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock); KPerBlock);
// b scale // b scale
static_assert(KPerBlock<=ScaleBlockK); static_assert(KPerBlock <= ScaleBlockK);
const index_t ScaleSliceSizeN = NXdlPerWave; const index_t ScaleSliceSizeN = NXdlPerWave;
const index_t ScaleSliceSizeK = 1; const index_t ScaleSliceSizeK = 1;
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}, Number<1>{})); make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
...@@ -1455,41 +1447,43 @@ struct GridwiseGemm_xdl_cshuffle_v3
BScaleType, BScaleType,
decltype(b_scale_grid_desc_bn_ak), decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc), decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK, 1>, Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1, 2>, Sequence<0, 1>,
1, 1,
1, 1,
1, 1,
false>( false>(
b_scale_grid_desc_bn_ak, b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0, 0)); make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));
constexpr auto b_scale_thread_slice_copy_step = constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0, 0), make_tuple(make_multi_index(NWaves * NPerXdl, 0),
make_multi_index(-NPerBlock, KPerBlock/ScaleBlockK, KPerBlock%ScaleBlockK)); make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, 1));
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1, blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_block_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
a_blockwise_copy, a_block_desc_ak0_m_ak1,
a_grid_buf, a_blockwise_copy,
a_block_buf, a_grid_buf,
a_block_slice_copy_step, a_block_buf,
b_grid_desc_bk0_n_bk1, a_block_slice_copy_step,
b_block_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_blockwise_copy, b_block_desc_bk0_n_bk1,
b_grid_buf, b_blockwise_copy,
b_block_buf, b_grid_buf,
b_block_slice_copy_step, b_block_buf,
c_thread_buf, b_block_slice_copy_step,
b_scale_grid_desc_bn_ak, c_thread_buf,
b_scale_thread_desc, b_scale_grid_desc_bn_ak,
b_scale_thread_copy, b_scale_thread_desc,
b_scale_grid_buf, b_scale_thread_copy,
b_scale_thread_slice_copy_step, b_scale_grid_buf,
num_k_block_main_loop, b_scale_thread_slice_copy_step,
num_k_block_per_scale); num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out // shuffle C and write out
{ {
...@@ -1756,7 +1750,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
math::integer_divide_ceil(problem.K, ScaleBlockK)), math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const AElementwiseOperation a_element_op{}; const AElementwiseOperation a_element_op{};
...@@ -1867,7 +1861,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
bit_cast<BDataType*>(static_cast<char*>(p_shared_0) + bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
a_block_space_size_aligned * sizeof(ADataType)/APackedSize), a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
...@@ -1875,7 +1869,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) + bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
a_block_space_size_aligned * sizeof(ADataType)/APackedSize), a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
...@@ -1924,28 +1918,29 @@ struct GridwiseGemm_xdl_cshuffle_v3
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1, blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_block_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
a_blockwise_copy, a_block_desc_ak0_m_ak1,
a_grid_buf, a_blockwise_copy,
a_block_bufs, a_grid_buf,
a_block_slice_copy_step, a_block_bufs,
b_grid_desc_bk0_n_bk1, a_block_slice_copy_step,
b_block_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_blockwise_copy, b_block_desc_bk0_n_bk1,
b_grid_buf, b_blockwise_copy,
b_block_bufs, b_grid_buf,
b_block_slice_copy_step, b_block_bufs,
c_thread_buf, b_block_slice_copy_step,
c_thread_buf,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc, b_scale_grid_desc_bn_ak,
b_scale_thread_copy, b_scale_thread_desc,
b_scale_grid_buf, b_scale_thread_copy,
b_scale_thread_slice_copy_step, b_scale_grid_buf,
b_scale_thread_slice_copy_step,
num_k_block_main_loop,
num_k_block_per_scale); num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out // shuffle C and write out
{ {
......
...@@ -1176,7 +1176,7 @@ struct ThreadwiseTensorSliceTransfer_v4
}); });
} }
else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value && else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, f8_t>::value) is_same<remove_cvref_t<DstData>, f8_t>::value)
{ {
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData) // DstData)
......
...@@ -39,20 +39,20 @@ template <typename ADataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
index_t ScaleBlockK> index_t ScaleBlockK>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmV2BScale<
ck::tensor_operation::device::DeviceGemmV2BScale<ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
BScaleDataType, BScaleDataType,
CDataType, CDataType,
1, 1,
ScaleBlockK, ScaleBlockK,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>> ck::tensor_operation::element_wise::PassThrough>>
{ {
using DeviceOp = DeviceGemmV2BScale<ALayout, using DeviceOp = DeviceGemmV2BScale<ALayout,
BLayout, BLayout,
...@@ -70,7 +70,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, half_t>) is_same_v<CDataType, half_t>)
{ {
......
...@@ -8,13 +8,22 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
DeviceGemmV2BScale<Row, Col, Row, F16, I4, F16, F16, 1, 128, PassThrough, PassThrough, PassThrough>>>& Col,
instances) Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
//device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{}); // device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{}); device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
} }
......
...@@ -30,24 +30,24 @@ template <typename ADataType,
typename ComputeDataType, typename ComputeDataType,
typename AccDataType, typename AccDataType,
typename CDataType, typename CDataType,
index_t ScaleBlockK, index_t ScaleBlockK,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
bool profile_gemm_b_scale_impl(int do_verification, bool profile_gemm_b_scale_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
int M, int M,
int N, int N,
int K, int K,
int StrideA, int StrideA,
int StrideB, int StrideB,
int StrideC, int StrideC,
int KBatch, int KBatch,
int n_warmup, int n_warmup,
int n_iter, int n_iter,
uint64_t rotating = 0) uint64_t rotating = 0)
{ {
bool pass = true; bool pass = true;
...@@ -66,24 +66,25 @@ bool profile_gemm_b_scale_impl(int do_verification,
}; };
ck::index_t Scale_Stride_BN = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor> ck::index_t Scale_Stride_BN = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? ((K + ScaleBlockK - 1) / ScaleBlockK) ? ((K + ScaleBlockK - 1) / ScaleBlockK)
: N; : N;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor(
N, // N direction group size is 1 (K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
Scale_Stride_BN, N, // N direction group size is 1
BLayout{})); Scale_Stride_BN,
BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() +
b_k_n.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
b1_k_n.GetElementSpaceSizeInBytes(); b1_k_n.GetElementSpaceSizeInBytes();
int rotating_count = std::max( int rotating_count = std::max(
1, 1,
std::min(n_iter, std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed)))); static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
...@@ -167,9 +168,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
i4 = i4 - 8; i4 = i4 - 8;
v_b = ck::type_convert<float>(i4); v_b = ck::type_convert<float>(i4);
b_k_n_dequant(k, n) = b_k_n_dequant(k, n) = ck::type_convert<float>(v_b) *
ck::type_convert<float>(v_b) * ck::type_convert<float>(b1_k_n(k / ScaleBlockK, n));
ck::type_convert<float>(b1_k_n(k / ScaleBlockK, n));
} }
} }
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -291,21 +291,21 @@ bool profile_gemm_b_scale_impl(int do_verification,
{ {
auto kbatch_curr = kbatch_list[i]; auto kbatch_curr = kbatch_list[i];
auto argument_ptr = auto argument_ptr = op_ptr->MakeArgumentPointer(
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()),
kbatch_curr, kbatch_curr,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
...@@ -32,8 +32,8 @@ enum struct GemmDataType
enum struct BScaleBlockTile enum struct BScaleBlockTile
{ {
K_64, // 0 K_64, // 0
K_128, // 1 K_128, // 1
}; };
#define OP_NAME "gemm_b_scale" #define OP_NAME "gemm_b_scale"
...@@ -82,7 +82,14 @@ int profile_gemm_b_scale(int argc, char* argv[])
const int StrideB = std::stoi(argv[13]); const int StrideB = std::stoi(argv[13]);
const int StrideC = std::stoi(argv[14]); const int StrideC = std::stoi(argv[14]);
const int KBatch = std::stoi(argv[15]); const int KBatch = std::stoi(argv[15]);
printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n", M, N, K, StrideA, StrideB, StrideC, KBatch); printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n",
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch);
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
...@@ -156,14 +163,18 @@ int profile_gemm_b_scale(int argc, char* argv[])
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
// if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN && B_scale_block == BScaleBlockTile::K_64) // if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
// B_scale_block == BScaleBlockTile::K_64)
// { // {
// return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{}, Row{}); // return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{},
// Row{});
// } // }
if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN && B_scale_block == BScaleBlockTile::K_128) if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
B_scale_block == BScaleBlockTile::K_128)
{ {
printf("F16_I4_F16 MK_NK_MN K_128\n"); printf("F16_I4_F16 MK_NK_MN K_128\n");
return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{}); return profile(
F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
} }
else else
{ {
......