Commit 03c2ba3a authored by aska-0096

bug fix + performance opt + clangformat

parent 1a324dfb
@@ -170,7 +170,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
     typename BBlockBuffer,
     typename BBlockTransferStep,
     typename CThreadBuffer,
-    //BScale Thread Copy
+    // BScale Thread Copy
     typename BScaleGridBuffer,
     typename BScaleGridDesc,
     typename BScaleThreadDesc,
...
@@ -719,12 +719,12 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
         const BBlockTransferStep& b_block_copy_step,
         CThreadBuffer& c_thread_buf,
         const BScaleGridDesc& b_scale_grid_desc,
-        //BScaleThreadCopy
+        // BScaleThreadCopy
         const BScaleThreadDesc& b_scale_thread_desc,
         BScaleThreadTransfer& b_scale_thread_copy,
         const BScaleGridBuffer& b_scale_grid_buf,
         const BScaleThreadTransferStep& b_scale_thread_copy_step,
-        //num loop
+        // num loop
         index_t num_loop,
         index_t num_loop_per_scale) const
     {
@@ -864,14 +864,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                     }
                 });
-                // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
+                // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
+                // {
                 // constexpr index_t c_offset =
                 // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                 // c_thread_buf(Number<c_offset>{}) +=
                 // c_thread_buf_per_scale[Number<t>{}] *
                 // type_convert<AccDataType>(b_scale_thread_buf[n0]);
                 // });
             });
         });
         __builtin_amdgcn_sched_barrier(0);
@@ -986,7 +986,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                 // c_thread_buf_per_scale[Number<t>{}] *
                 // type_convert<AccDataType>(b_scale_thread_buf[n0]);
                 // });
             });
         });
         __builtin_amdgcn_sched_barrier(0);
@@ -1084,7 +1083,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                 // c_thread_buf_per_scale[Number<t>{}] *
                 // type_convert<AccDataType>(b_scale_thread_buf[n0]);
                 // });
             });
         });
         __builtin_amdgcn_sched_barrier(0);
...
@@ -295,25 +295,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
         BBlockBuffer& b_block_buf,
         const BBlockTransferStep& b_block_copy_step,
         CThreadBuffer& c_thread_buf,
-        //BScaleThreadCopy
+        // BScaleThreadCopy
         const BScaleGridDesc& b_scale_grid_desc,
         const BScaleThreadDesc& b_scale_thread_desc,
         BScaleThreadTransfer& b_scale_thread_copy,
         const BScaleGridBuffer& b_scale_grid_buf,
         const BScaleThreadTransferStep& b_scale_thread_copy_step,
-        //num loop
+        // num loop
         index_t num_loop,
         index_t num_loop_per_scale) const
     {
         __builtin_amdgcn_sched_barrier(0);
-        ignore = num_loop_per_scale;
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             a_thread_desc_.GetElementSpaceSize());
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_thread_desc_.GetElementSpaceSize());
-        //B scale buffer
+        // B scale buffer
         auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_scale_thread_desc.GetElementSpaceSize());
@@ -328,14 +327,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
             b_scale_thread_copy.Run(b_scale_grid_desc,
                                     b_scale_grid_buf,
                                     b_scale_thread_desc,
-                                    make_tuple(n0, I0, I0),
+                                    make_tuple(n0, I0),
                                     b_scale_thread_buf);
             b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                    b_scale_thread_copy_step.At(Number<0>{}));
         });
-        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                               b_scale_thread_copy_step.At(Number<1>{}));
+        if(num_loop_per_scale == 1)
+        {
+            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                   b_scale_thread_copy_step.At(Number<2>{}));
+        }
+        else
+        {
+            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                   b_scale_thread_copy_step.At(Number<1>{}));
+        }
         // Local prefill 1
         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
@@ -350,7 +358,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
         // Initialize C
         c_thread_buf.Clear();
-        auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>(); // need actually?
+        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+                                  AccDataType,
+                                  1,
+                                  xdlops_gemm.GetRegSizePerXdlops(),
+                                  true>
+            c_thread_buf_per_scale;
         // Local prefetch 1
         block_sync_lds();
@@ -415,8 +429,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                     // constexpr index_t c_offset =
                     // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.Run(
-                        a_thread_vec.template AsType<mfma_input_type>(),
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                     b_thread_vec.template AsType<mfma_input_type>(),
                                     c_thread_buf_per_scale.GetVectorTypeReference(I0));
                 });
@@ -455,15 +468,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                 b_scale_thread_copy.Run(b_scale_grid_desc,
                                         b_scale_grid_buf,
                                         b_scale_thread_desc,
-                                        make_tuple(n0, I0, I0),
+                                        make_tuple(n0, I0),
                                         b_scale_thread_buf);
                 b_scale_thread_copy.MoveSrcSliceWindow(
                     b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
             });
-            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                                   b_scale_thread_copy_step.At(Number<1>{}));
+            if((i + 2) % num_loop_per_scale == 0)
+            {
+                b_scale_thread_copy.MoveSrcSliceWindow(
+                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
+            }
+            else
+            {
+                b_scale_thread_copy.MoveSrcSliceWindow(
+                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
+            }
             HotLoopScheduler();
             __builtin_amdgcn_sched_barrier(0);
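Note on the new scale-window stepping above: step `Number<1>{}` rewinds the N offset and stays in the current K-direction scale column, while the new step `Number<2>{}` rewinds N and advances one column (the step tuple itself is defined in the gridwise hunk further below). As a rough illustration, here is a host-side sketch, not part of the commit and with assumed example numbers, of how often the `(i + 2) % num_loop_per_scale == 0` test fires relative to the scale column a given K-loop slice actually needs:

```cpp
#include <cstdio>

// Plain host sketch (assumed values): with num_loop_per_scale = ScaleBlockK / KPerBlock,
// K-loop iteration i consumes scale column (i * KPerBlock) / ScaleBlockK, so the copy
// window has to advance once every num_loop_per_scale iterations. The modulo test in
// the pipeline fires at exactly that rate, shifted by the pipeline's prefetch distance.
int main()
{
    const int KPerBlock          = 64;
    const int ScaleBlockK        = 128;
    const int num_loop           = 8;
    const int num_loop_per_scale = ScaleBlockK / KPerBlock; // 2

    for(int i = 0; i < num_loop; ++i)
    {
        const int scale_col = (i * KPerBlock) / ScaleBlockK;
        const bool advance  = ((i + 2) % num_loop_per_scale == 0);
        printf("iter %d: scale column %d, window step this iteration: %s\n",
               i,
               scale_col,
               advance ? "Number<2> (next column)" : "Number<1> (same column)");
    }
    return 0;
}
```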
...
@@ -268,13 +268,13 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
         BBlockBuffer& b_block_buf,
         const BBlockTransferStep& b_block_copy_step,
         CThreadBuffer& c_thread_buf,
-        //BScaleThreadCopy
+        // BScaleThreadCopy
         const BScaleGridDesc& b_scale_grid_desc,
         const BScaleThreadDesc& b_scale_thread_desc,
         BScaleThreadTransfer& b_scale_thread_copy,
         const BScaleGridBuffer& b_scale_grid_buf,
         const BScaleThreadTransferStep& b_scale_thread_copy_step,
-        //num loop
+        // num loop
         index_t num_loop,
         index_t num_loop_per_scale) const
     {
@@ -284,7 +284,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_thread_desc_.GetElementSpaceSize());
-        //B scale buffer
+        // B scale buffer
         auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_scale_thread_desc.GetElementSpaceSize());
@@ -409,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                                         make_tuple(n0, I0),
                                         b_scale_thread_bufs(lds_read_reg_buf));
-                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                                       b_scale_thread_copy_step.At(Number<0>{}));
+                b_scale_thread_copy.MoveSrcSliceWindow(
+                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
             });
-            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                                   b_scale_thread_copy_step.At(Number<1>{}));
+            b_scale_thread_copy.MoveSrcSliceWindow(
+                b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
             a_blockwise_copy.RunWrite(
                 a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
@@ -426,7 +426,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     c_thread_buf_per_scale.Clear();
@@ -452,8 +451,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                     // constexpr index_t c_offset =
                     // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.Run(
-                        a_thread_vec.template AsType<mfma_input_type>(),
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                     b_thread_vec.template AsType<mfma_input_type>(),
                                     c_thread_buf_per_scale.GetVectorTypeReference(I0));
                 });
@@ -462,7 +460,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                         c_thread_buf(Number<c_offset>{}) +=
                             c_thread_buf_per_scale[Number<t>{}] *
-                            type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]);
+                            type_convert<AccDataType>(
+                                b_scale_thread_bufs(mfma_reg_buf)[n0]);
                     });
                 });
             });
@@ -521,7 +520,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     c_thread_buf_per_scale.Clear();
@@ -640,7 +638,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
         };
         auto CompFunc = [&](auto mfma_reg_buf) {
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     c_thread_buf_per_scale.Clear();
...
@@ -218,7 +218,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
         };
         constexpr index_t minimum_occupancy =
-            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
+            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
+                ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
+                   MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
+                      ? 1
+                      : 2
+                : 2;
         if(has_main_k_block_loop)
         {
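For context, a minimal standalone sketch of the occupancy rule introduced above (not CK code; the helper name and the `unsigned short` stand-in for fp16 are assumptions): Intrawave v3 pipelines request a minimum occupancy of 1 only when the block tile volume in bytes exceeds that of a 128x128x64 fp16 tile, which the change appears to use as a proxy for register and LDS pressure; every other configuration keeps 2.

```cpp
#include <cstddef>

// Sketch of the minimum_occupancy selection above (hypothetical helper).
template <int MPerBlock, int NPerBlock, int KPerBlock, typename ADataType>
constexpr int minimum_occupancy_for(bool intrawave, bool pipeline_v3)
{
    return intrawave
               ? ((pipeline_v3 &&
                   std::size_t(MPerBlock) * NPerBlock * KPerBlock * sizeof(ADataType) >
                       std::size_t(128) * 128 * 64 * 2)
                      ? 1
                      : 2)
               : 2;
}

// A 256x256x64 fp16-sized tile drops to occupancy 1; 128x128x64 keeps 2.
static_assert(minimum_occupancy_for<256, 256, 64, unsigned short>(true, true) == 1, "");
static_assert(minimum_occupancy_for<128, 128, 64, unsigned short>(true, true) == 2, "");
static_assert(minimum_occupancy_for<256, 256, 64, unsigned short>(false, true) == 2, "");
```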
@@ -664,7 +669,20 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, p_b_scale, KBatch, a_element_op, b_element_op, c_element_op};
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        p_b_scale,
+                        KBatch,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
     }
     static auto MakeInvoker() { return Invoker{}; }
...
@@ -13,9 +13,9 @@ namespace ck {
 __host__ __device__ inline half4_t pki4_to_half4(int q)
 {
-    const int LO = 0x000f000f;
-    const int HI = 0x00f000f0;
-    const int EX = 0x64006400;
+    constexpr int LO = 0x000f000f;
+    constexpr int HI = 0x00f000f0;
+    constexpr int EX = 0x64006400;
     // Guarantee that the `(a & b) | c` operations are LOP3s.
     // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
     // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
@@ -23,9 +23,9 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
     int hi = amd_assembly_and_or_b32(q, HI, EX);
     // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
     // directly into `SUB` and `ADD`.
-    const int SUB = 0xE408E408; //-8
-    const int MUL = 0x2c002c00; // 1/16
-    const int ADD = 0xd480d480; //-79
+    constexpr int SUB = 0xE408E408; //-8
+    constexpr int MUL = 0x2c002c00; // 1/16
+    constexpr int ADD = 0xd480d480; //-79
     vector_type<half_t, 4> res;
@@ -34,7 +34,15 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
     res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
         bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
+#if 0
+    asm volatile("v_and_or_b32 %0, %4, %5, %7 \n \
+                  v_and_or_b32 %1, %4, %6, %7 \n \
+                  v_pk_add_f16 %2, %0, %8 \n \
+                  v_pk_fma_f16 %3, %1, %9, %10 \
+                  "
+                 : "=v"(lo), "=v"(hi), "=v"(res.template AsType<half2_t>()(Number<0>{})), "=v"(res.template AsType<half2_t>()(Number<1>{}))
+                 : "v"(q), "v"(LO), "v"(HI), "s"(EX), "s"(SUB), "v"(MUL), "s"(ADD), "0"(lo), "1"(hi));
+#endif
     return res.template AsType<half4_t>()[Number<0>{}];
 }
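The `pki4_to_half4` changes above only touch qualifiers and add a disabled inline-asm variant, but the magic constants are easy to misread, so here is a host-side sanity sketch (plain C++, not part of the commit; the decoder and the loop are assumptions) showing why the `EX`, `SUB`, `MUL`, and `ADD` bit patterns turn a packed 4-bit value x into the signed half value x - 8, i.e. the `-8` symmetric zero point mentioned in the source comments:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an IEEE fp16 bit pattern to float (no inf/nan handling; the constants
// used here are all finite).
static float half_bits_to_float(uint16_t h)
{
    const int sign = (h >> 15) & 0x1;
    const int exp  = (h >> 10) & 0x1f;
    const int man  = h & 0x3ff;
    const float v  = (exp == 0) ? std::ldexp(static_cast<float>(man), -24)
                                : std::ldexp(1.0f + man / 1024.0f, exp - 15);
    return sign ? -v : v;
}

int main()
{
    const float ex  = half_bits_to_float(0x6400); // 1024: bias ORed into the mantissa
    const float sub = half_bits_to_float(0xE408); // removes the bias and the +8 zero point
    const float mul = half_bits_to_float(0x2c00); // 1/16: rescales the high nibble
    const float add = half_bits_to_float(0xd480); // removes the rescaled bias and zero point

    for(int x = 0; x < 16; ++x)
    {
        // Low nibble: bit pattern (x | 0x6400) is the half value 1024 + x.
        const float lo = (ex + x) + sub;
        // High nibble: bit pattern ((x << 4) | 0x6400) is the half value 1024 + 16 * x.
        const float hi = (ex + 16.0f * x) * mul + add;
        assert(lo == static_cast<float>(x - 8));
        assert(hi == static_cast<float>(x - 8));
    }
    return 0;
}
```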
...
@@ -48,12 +48,7 @@ __global__ void
         // karg);
         GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-            karg.p_a_grid,
-            karg.p_b_grid,
-            karg.p_c_grid,
-            karg.p_b_scale_grid,
-            p_shared,
-            karg);
+            karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, karg.p_b_scale_grid, p_shared, karg);
 #else
     ignore = karg;
 #endif // end of if (defined(__gfx9__))
@@ -1303,11 +1298,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         // B Scale grid and buffer
         const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
             make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
-                       math::integer_divide_ceil(problem.K, ScaleBlockK),
-                       ScaleBlockK),
-            make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK),
-                       1,
-                       0));
+                       math::integer_divide_ceil(problem.K, ScaleBlockK)),
+            make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
         const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
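The B-scale grid descriptor above is now a plain 2-D (n-block, k-block) view with strides (ceil(K/ScaleBlockK), 1), replacing the earlier 3-D form. A small host-side sketch of the implied offset math (assumed struct and example numbers, not CK code), using the per-column N scaling and ScaleBlockK = 128 configuration registered by the instances later in this commit:

```cpp
#include <cstdint>

// One scale per (ScaleBlockN x ScaleBlockK) tile of B, stored row-major in N.
struct BScaleView
{
    int64_t n_blocks; // ceil(N / ScaleBlockN)
    int64_t k_blocks; // ceil(K / ScaleBlockK)

    // Mirrors make_naive_tensor_descriptor(..., make_tuple(k_blocks, 1)) above.
    int64_t offset(int64_t n_blk, int64_t k_blk) const { return n_blk * k_blocks + k_blk; }
};

// Example: N = 4096, K = 4096, ScaleBlockN = 1, ScaleBlockK = 128
// -> a 4096 x 32 scale tensor; the scale for element (k = 300, n = 70) of B
//    sits at offset 70 * 32 + 300 / 128 = 2242.
```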
@@ -1438,12 +1430,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                           KPerBlock);
         // b scale
-        static_assert(KPerBlock<=ScaleBlockK);
+        static_assert(KPerBlock <= ScaleBlockK);
         const index_t ScaleSliceSizeN = NXdlPerWave;
         const index_t ScaleSliceSizeK = 1;
         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}, Number<1>{}));
+            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
         constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
@@ -1455,22 +1447,24 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                        BScaleType,
                                                        decltype(b_scale_grid_desc_bn_ak),
                                                        decltype(b_scale_thread_desc),
-                                                       Sequence<1, ScaleSliceSizeK, 1>,
-                                                       Sequence<0, 1, 2>,
+                                                       Sequence<1, ScaleSliceSizeK>,
+                                                       Sequence<0, 1>,
                                                        1,
                                                        1,
                                                        1,
                                                        false>(
             b_scale_grid_desc_bn_ak,
-            make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0, 0));
+            make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));
         constexpr auto b_scale_thread_slice_copy_step =
-            make_tuple(make_multi_index(NWaves * NPerXdl, 0, 0),
-                       make_multi_index(-NPerBlock, KPerBlock/ScaleBlockK, KPerBlock%ScaleBlockK));
+            make_tuple(make_multi_index(NWaves * NPerXdl, 0),
+                       make_multi_index(-NPerBlock, 0),
+                       make_multi_index(-NPerBlock, 1));
         const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
-        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
+        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
+            a_grid_desc_ak0_m_ak1,
             a_block_desc_ak0_m_ak1,
             a_blockwise_copy,
             a_grid_buf,
@@ -1867,7 +1861,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
-                                 a_block_space_size_aligned * sizeof(ADataType)/APackedSize),
+                                 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
         auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
@@ -1875,7 +1869,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
-                                 a_block_space_size_aligned * sizeof(ADataType)/APackedSize),
+                                 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
         auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
@@ -1924,7 +1918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
-        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
+        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
+            a_grid_desc_ak0_m_ak1,
             a_block_desc_ak0_m_ak1,
             a_blockwise_copy,
             a_grid_buf,
...
@@ -40,8 +40,8 @@ template <typename ADataType,
           typename BLayout,
           typename CLayout,
           index_t ScaleBlockK>
-struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceGemmV2BScale<ALayout,
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmV2BScale<
+    ALayout,
     BLayout,
     CLayout,
     ADataType,
...
@@ -8,13 +8,22 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
-    std::vector<std::unique_ptr<
-        DeviceGemmV2BScale<Row, Col, Row, F16, I4, F16, F16, 1, 128, PassThrough, PassThrough, PassThrough>>>&
-        instances)
+    std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
+                                                   Col,
+                                                   Row,
+                                                   F16,
+                                                   I4,
+                                                   F16,
+                                                   F16,
+                                                   1,
+                                                   128,
+                                                   PassThrough,
+                                                   PassThrough,
+                                                   PassThrough>>>& instances)
 {
     add_device_operation_instances(
         instances,
-        //device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
+        // device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
         device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
 }
...
@@ -72,7 +72,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
+    Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor(
+        (K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
         N,                                   // N direction group size is 1
         Scale_Stride_BN,
         BLayout{}));
@@ -167,8 +168,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
                 i4 = i4 - 8;
                 v_b = ck::type_convert<float>(i4);
-                b_k_n_dequant(k, n) =
-                    ck::type_convert<float>(v_b) *
+                b_k_n_dequant(k, n) = ck::type_convert<float>(v_b) *
                     ck::type_convert<float>(b1_k_n(k / ScaleBlockK, n));
             }
         }
@@ -291,8 +291,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
         {
             auto kbatch_curr = kbatch_list[i];
-            auto argument_ptr =
-                op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+            auto argument_ptr = op_ptr->MakeArgumentPointer(
+                static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                 static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                 static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                 M,
...
@@ -82,7 +82,14 @@ int profile_gemm_b_scale(int argc, char* argv[])
     const int StrideB = std::stoi(argv[13]);
     const int StrideC = std::stoi(argv[14]);
     const int KBatch = std::stoi(argv[15]);
-    printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n", M, N, K, StrideA, StrideB, StrideC, KBatch);
+    printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n",
+           M,
+           N,
+           K,
+           StrideA,
+           StrideB,
+           StrideC,
+           KBatch);
     int n_warmup = 1;
     int n_iter = 10;
@@ -156,14 +163,18 @@ int profile_gemm_b_scale(int argc, char* argv[])
         return pass ? 0 : 1;
     };
-    // if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN && B_scale_block == BScaleBlockTile::K_64)
+    // if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
+    //    B_scale_block == BScaleBlockTile::K_64)
     // {
-    //     return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{}, Row{});
+    //     return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{},
+    //                    Row{});
     // }
-    if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN && B_scale_block == BScaleBlockTile::K_128)
+    if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
+       B_scale_block == BScaleBlockTile::K_128)
     {
         printf("F16_I4_F16 MK_NK_MN K_128\n");
-        return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
+        return profile(
+            F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
     }
     else
     {
...