Commit e35f2c1a authored by aska-0096

optimize perf; enable v4; i4_bufferload_not_solved

parent 27f9ed07
@@ -359,13 +359,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
1,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_per_scale;
// Local prefetch 1 // Local prefetch 1
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
@@ -381,6 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_scale_thread_buf[n0],
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
@@ -406,10 +400,31 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { b_scale_thread_copy.Run(b_scale_grid_desc,
c_thread_buf_per_scale.Clear(); b_scale_grid_buf,
static_for<0, KRepeat, 1>{}([&](auto k0) { b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
});
if((i + 2) % num_loop_per_scale == 0)
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
}
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -426,20 +441,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * xdlops_gemm.Run(
// type_convert<AccDataType>(a_scale_thread_buf[I0]) * a_thread_vec.template AsType<mfma_input_type>(),
type_convert<AccDataType>(b_scale_thread_buf[n0]); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
}); });
@@ -459,32 +467,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_scale_thread_buf[n0],
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
}); });
}); });
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
});
if((i + 2) % num_loop_per_scale == 0)
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
}
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
@@ -495,10 +483,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// tail // tail
if constexpr(TailNum == TailNumber::Full) if constexpr(TailNum == TailNumber::Full)
{ {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
c_thread_buf_per_scale.Clear(); static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -514,17 +501,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
// type_convert<AccDataType>(a_scale_thread_buf[I0]) * b_thread_vec.template AsType<mfma_input_type>(),
type_convert<AccDataType>(b_scale_thread_buf[n0]); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
}); });
...
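The v3 hunks above drop the per-(m0, n0) partial accumulator (c_thread_buf_per_scale, which was scaled by b_scale_thread_buf after the K loop) and instead pass b_scale_thread_buf[n0] into b_thread_copy_.Run, so the scale is folded into the B dequantization and the MFMA results accumulate straight into c_thread_buf. A minimal host-side sketch of the identity this relies on (a scalar stand-in for the MFMA, illustrative values only):

#include <cassert>

int main()
{
    // One (m0, n0) position, K unrolled by 4; values are arbitrary.
    const float a[4]    = {1.f, 2.f, 3.f, 4.f};
    const float b[4]    = {5.f, 6.f, 7.f, 8.f};
    const float b_scale = 0.25f; // one scale covering the whole accumulated K range

    float c_fused = 0.f, c_post_scaled = 0.f;
    for(int k = 0; k < 4; ++k)
    {
        c_fused       += a[k] * (b[k] * b_scale); // new path: scale fused into the B dequant
        c_post_scaled += a[k] * b[k];             // old path: unscaled partial accumulator
    }
    c_post_scaled *= b_scale; // old path: scale applied once after the K loop

    // Mathematically identical; on the device the fused path rounds the scaled B
    // values to fp16 before the MFMA, so results may differ by that rounding.
    assert(c_fused == c_post_scaled);
    return 0;
}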
@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
static constexpr index_t PrefetchStages = 4; static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 2; static constexpr index_t PrefillStages = 2;
static constexpr index_t GlobalBufferNum = 2; static constexpr index_t GlobalBufferNum = 1;
static constexpr index_t HotloopUnroll = 2; static constexpr index_t HotloopUnroll = 2;
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
@@ -164,8 +164,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
} }
} }
template <typename ScheduleGroup> __device__ static constexpr void HotLoopScheduler()
__device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group)
{ {
// TODO: Take data type into consideration as pipe ver 3 // TODO: Take data type into consideration as pipe ver 3
// A-B splited schedule // A-B splited schedule
@@ -195,42 +194,42 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
ignore = i; ignore = i;
static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
ignore = idsread; ignore = idsread;
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite; ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, __builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_dsread_per_issue_a - num_mfma_per_issue - num_dsread_per_issue_a -
num_dswrite_per_issue_a, num_dswrite_per_issue_a,
schedule_group); // MFMA 0); // MFMA
}); });
static_for<0, num_issue_b, 1>{}([&](auto i) { static_for<0, num_issue_b, 1>{}([&](auto i) {
ignore = i; ignore = i;
static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
ignore = idsread; ignore = idsread;
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite; ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, __builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_dsread_per_issue_a - num_mfma_per_issue - num_dsread_per_issue_a -
num_dswrite_per_issue_b, num_dswrite_per_issue_b,
schedule_group); // MFMA 0); // MFMA
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
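// For reference: __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) creates a
// scheduling group of `size` instructions whose type matches `mask` (0x008 MFMA,
// 0x020 VMEM read, 0x100 DS read, 0x200 DS write, per the comments above); groups
// that share a sync_id are ordered against each other. With the ScheduleGroup
// parameter removed, every group above is pinned to sync ID 0.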
@@ -278,7 +277,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
index_t num_loop, index_t num_loop,
index_t num_loop_per_scale) const index_t num_loop_per_scale) const
{ {
ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
@@ -293,8 +291,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs; StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// Global prefetch 1 // Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -309,23 +307,50 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{})); b_scale_thread_copy_step.At(Number<0>{}));
}); });
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{})); if(num_loop_per_scale == 1)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
}
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
// Global prefetch 2 // Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefill 1 static_for<0, NRepeat, 1>{}([&](auto n0) {
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0); b_scale_thread_copy.Run(b_scale_grid_desc,
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0); b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_bufs(I1));
// Local prefill 2 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1); b_scale_thread_copy_step.At(Number<0>{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1); });
if(2 % num_loop_per_scale == 0)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
}
// Local prefetch 1 // Local prefetch 1
block_sync_lds(); block_sync_lds();
@@ -341,6 +366,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0), b_block_buf.At(I0),
b_scale_thread_bufs(I0)[n0],
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k, I0), make_tuple(n0, I0, k, I0),
b_thread_bufs(I0)); b_thread_bufs(I0));
@@ -348,23 +374,41 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
}); });
}); });
// Local prefill 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
// Global prefetch 3 // Global prefetch 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Global prefetch 4 static_for<0, NRepeat, 1>{}([&](auto n0) {
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); b_scale_thread_copy.Run(b_scale_grid_desc,
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_scale_thread_copy_step.At(Number<0>{}));
});
if(3 % num_loop_per_scale == 0)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
}
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>(); // need actually?
// main body // main body
if constexpr(HasMainLoop) if constexpr(HasMainLoop)
...@@ -376,9 +420,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -376,9 +420,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
auto LoopFunc = [&](auto lds_read_buf, auto LoopFunc = [&](auto lds_read_buf,
auto lds_read_reg_buf, auto lds_read_reg_buf,
auto lds_write_buf, auto lds_write_buf,
auto vmem_buf, auto mfma_reg_buf) {
auto mfma_reg_buf,
auto schedule_group) {
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
@@ -389,15 +431,15 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) { });
b_thread_copy_.Run( static_for<0, NRepeat, 1>{}([&](auto n0) {
b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf), b_block_buf.At(lds_read_buf),
b_thread_desc_, b_scale_thread_bufs(lds_read_buf)[n0],
make_tuple(n0, I0, k, I0), b_thread_desc_,
b_thread_bufs(lds_read_reg_buf)); make_tuple(n0, I0, k, I0),
}); b_thread_bufs(lds_read_reg_buf));
}); });
}); });
@@ -412,24 +454,30 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.MoveSrcSliceWindow( b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
}); });
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy.RunWrite( if((i + 4 + mfma_reg_buf.value) % num_loop_per_scale == 0)
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); {
b_blockwise_copy.RunWrite( b_scale_thread_copy.MoveSrcSliceWindow(
b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
}
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
c_thread_buf_per_scale.Clear(); static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -448,29 +496,22 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * xdlops_gemm.Run(
type_convert<AccDataType>( a_thread_vec.template AsType<mfma_input_type>(),
b_scale_thread_bufs(mfma_reg_buf)[n0]); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
}); });
HotLoopScheduler(schedule_group); HotLoopScheduler();
}; };
LoopFunc(I1, I1, I0, I0, I0, I0); LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1, I1, I0); LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll; i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages)); } while(i < (num_loop - PrefetchStages));
@@ -479,9 +520,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
auto ReadWriteCompFunc = [&](auto lds_read_buf, auto ReadWriteCompFunc = [&](auto lds_read_buf,
auto lds_read_reg_buf, auto lds_read_reg_buf,
auto lds_write_buf, auto lds_write_buf,
auto vmem_buf, auto mfma_reg_buf) {
auto mfma_reg_buf,
auto schedule_group) {
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
@@ -492,38 +531,24 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) { });
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, static_for<0, NRepeat, 1>{}([&](auto n0) {
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
b_block_buf.At(lds_read_buf), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_thread_desc_, b_block_buf.At(lds_read_buf),
make_tuple(n0, I0, k, I0), b_scale_thread_bufs(lds_read_buf)[n0],
b_thread_bufs(lds_read_reg_buf)); b_thread_desc_,
}); make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
}); });
}); });
// B scale copy a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) { b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_bufs(lds_read_reg_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
c_thread_buf_per_scale.Clear(); static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -539,30 +564,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
// constexpr index_t c_offset = constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]);
}); });
}); });
}); });
HotLoopScheduler(schedule_group); HotLoopScheduler();
}; };
auto ReadCompFunc = [&](auto lds_read_buf, auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
auto lds_read_reg_buf,
auto mfma_reg_buf,
auto schedule_group) {
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
@@ -573,35 +588,21 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
static_for<0, NRepeat, 1>{}([&](auto n0) { });
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, static_for<0, NRepeat, 1>{}([&](auto n0) {
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
b_block_buf.At(lds_read_buf), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_thread_desc_, b_block_buf.At(lds_read_buf),
make_tuple(n0, I0, k, I0), b_scale_thread_bufs(lds_read_buf)[n0],
b_thread_bufs(lds_read_reg_buf)); b_thread_desc_,
}); make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
}); });
}); });
// B scale copy static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
b_scale_thread_copy.Run(b_scale_grid_desc, static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_bufs(lds_read_reg_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -617,31 +618,23 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
// constexpr index_t c_offset = constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]);
}); });
}); });
}); });
HotLoopScheduler(schedule_group); HotLoopScheduler();
}; };
auto CompFunc = [&](auto mfma_reg_buf) { auto CompFunc = [&](auto mfma_reg_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
c_thread_buf_per_scale.Clear(); static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -657,35 +650,27 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
// constexpr index_t c_offset = constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]);
}); });
}); });
}); });
}; };
// tail // tail
if constexpr(TailNum == TailNumber::Odd) if constexpr(TailNum == TailNumber::Odd)
{ {
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); ReadWriteCompFunc(I1, I1, I0, I0);
ReadCompFunc(I0, I0, I1, I1); ReadCompFunc(I0, I0, I1);
CompFunc(I0); CompFunc(I0);
} }
else if constexpr(TailNum == TailNumber::Even) else if constexpr(TailNum == TailNumber::Even)
{ {
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); ReadCompFunc(I1, I1, I0);
ReadWriteCompFunc(I0, I0, I1, I1, I1, I1);
ReadCompFunc(I1, I1, I0, I1);
CompFunc(I1); CompFunc(I1);
} }
} }
...
@@ -220,9 +220,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
constexpr index_t minimum_occupancy = constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2) MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
? 1 ? 2
: 2 : 1
: 2; : 2;
if(has_main_k_block_loop) if(has_main_k_block_loop)
...
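Two things change in the occupancy hunk above: the v3 Intrawave test is restated positively (tiles with MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) at or below 128 * 128 * 64 * 2 still request occupancy 2, larger ones 1), and the remaining Intrawave pipelines now fall through to 1 instead of 2. A stand-alone sketch of just the v3 branch, with a hypothetical helper name:

#include <cstddef>

// Hypothetical helper mirroring only the v3/Intrawave branch of the heuristic above.
constexpr int minimum_occupancy_v3_intrawave(std::size_t m_per_block,
                                             std::size_t n_per_block,
                                             std::size_t k_per_block,
                                             std::size_t a_elem_bytes)
{
    return (m_per_block * n_per_block * k_per_block * a_elem_bytes <=
            std::size_t{128} * 128 * 64 * 2)
               ? 2  // small tile: request occupancy 2
               : 1; // large tile: settle for occupancy 1
}

static_assert(minimum_occupancy_v3_intrawave(128, 128, 64, 2) == 2, "boundary tile");
static_assert(minimum_occupancy_v3_intrawave(256, 256, 64, 2) == 1, "large tile");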
@@ -11,6 +11,98 @@
namespace ck { namespace ck {
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
{
constexpr int LO = 0x000f000f;
constexpr int HI = 0x00f000f0;
constexpr int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = amd_assembly_and_or_b32(q, LO, EX);
int hi = amd_assembly_and_or_b32(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
constexpr int SUB = 0xE408E408; // -1032.0 in fp16: (1024 + q) - 1032 = q - 8
constexpr int MUL = 0x2c002c00; // 1/16
constexpr int ADD = 0xd480d480; // -72.0 in fp16: (1024 + 16*q)/16 - 72 = q - 8
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
asm volatile("v_pk_mul_f16 %0, %1, %2"
: "=v"(res.template AsType<half2_t>()(Number<0>{}))
: "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
asm volatile("v_pk_mul_f16 %0, %1, %2"
: "=v"(res.template AsType<half2_t>()(Number<1>{}))
: "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
return res.template AsType<half4_t>()[Number<0>{}];
}
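A host-side sanity sketch (standalone, not part of this header) for the magic constants above: it decodes the fp16 bit patterns and checks that both nibble paths reproduce q - 8 for every int4 value. It assumes IEEE binary16 semantics and ignores the packed/SIMD aspect of the device code.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE binary16 bit pattern to double (normal numbers only,
// which is all these constants are).
static double half_bits_to_double(uint16_t h)
{
    const int sign = (h >> 15) & 0x1;
    const int exp  = (h >> 10) & 0x1f;
    const int man  = h & 0x3ff;
    const double v = std::ldexp(1.0 + man / 1024.0, exp - 15);
    return sign ? -v : v;
}

int main()
{
    const double SUB = half_bits_to_double(0xE408); // -1032
    const double MUL = half_bits_to_double(0x2c00); // 1/16
    const double ADD = half_bits_to_double(0xd480); // -72

    for(int q = 0; q < 16; ++q)
    {
        // lo nibble: OR-ing with 0x6400 makes the fp16 value 1024 + q;
        // the packed add then removes the bias and the +8 zero point.
        const double lo = half_bits_to_double(uint16_t(0x6400 | q)) + SUB;
        // hi nibble: the nibble sits 4 bits higher, so the same trick yields
        // 1024 + 16*q; the packed fma divides by 16 and subtracts 72.
        const double hi = half_bits_to_double(uint16_t(0x6400 | (q << 4))) * MUL + ADD;
        assert(lo == double(q - 8));
        assert(hi == double(q - 8));
    }
    std::printf("int4 -> fp16 dequant constants check out\n");
    return 0;
}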
// Further fusing the scale into the inline-assembly constants failed the sanity check, so this variant is disabled.
#if 0
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half_t& scale)
{
constexpr int LO = 0x000f000f;
constexpr int HI = 0x00f000f0;
constexpr int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = amd_assembly_and_or_b32(q, LO, EX);
int hi = amd_assembly_and_or_b32(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
// constexpr int SUB = 0xE408E408; //-8
// constexpr int MUL = 0x2c002c00; // 1/16
// constexpr int ADD = 0xd480d480; //-79
constexpr half_t SUB = bit_cast<half_t>(static_cast<uint16_t>(0xE408));
constexpr half_t MUL = bit_cast<half_t>(static_cast<uint16_t>(0x2c00));
constexpr half_t ADD = bit_cast<half_t>(static_cast<uint16_t>(0xd480));
vector_type<half_t, 2> scale_2;
scale_2.template AsType<half_t>()(Number<0>{}) = scale;
scale_2.template AsType<half_t>()(Number<1>{}) = scale;
vector_type<half_t, 2> sub_2;
sub_2.template AsType<half_t>()(Number<0>{}) = SUB * scale;
sub_2.template AsType<half_t>()(Number<1>{}) = SUB * scale;
vector_type<half_t, 2> mul_2;
mul_2.template AsType<half_t>()(Number<0>{}) = MUL * scale;
mul_2.template AsType<half_t>()(Number<1>{}) = MUL * scale;
vector_type<half_t, 2> add_2;
add_2.template AsType<half_t>()(Number<0>{}) = ADD * scale;
add_2.template AsType<half_t>()(Number<1>{}) = ADD * scale;
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_fma_f16(bit_cast<half2_t>(lo),
scale_2.template AsType<half2_t>()(Number<0>{}),
sub_2.template AsType<half2_t>()(Number<0>{}));
res.template AsType<half2_t>()(Number<1>{}) =
amd_assembly_pk_fma_f16(bit_cast<half2_t>(hi),
mul_2.template AsType<half2_t>()(Number<0>{}),
add_2.template AsType<half2_t>()(Number<0>{}));
// asm volatile("v_pk_mul_f16 %0, %1, %2"
// : "=v"(res.template AsType<half2_t>()(Number<0>{}))
// : "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
// asm volatile("v_pk_mul_f16 %0, %1, %2"
// : "=v"(res.template AsType<half2_t>()(Number<1>{}))
// : "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
return res.template AsType<half4_t>()[Number<0>{}];
}
#endif
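// A plausible reason the fused variant above fails: SUB * scale, MUL * scale and
// ADD * scale are pre-rounded to fp16, so the large exponent-bias terms they must
// cancel are rounded at their own magnitude; the leftover error is on the order of
// scale / 2, which is significant next to the (q - 8) * scale signal. The enabled
// version removes the bias with exact unscaled constants first and multiplies by
// the scale afterwards, so it avoids that cancellation problem.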
__host__ __device__ inline half4_t pki4_to_half4(int q) __host__ __device__ inline half4_t pki4_to_half4(int q)
{ {
constexpr int LO = 0x000f000f; constexpr int LO = 0x000f000f;
@@ -119,6 +211,69 @@ struct PassThroughPack8
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x)); result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8); result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
y = result.template AsType<half8_t>()[Number<0>{}];
#else
vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
}
constexpr const static bool is_pack8_invocable = true;
};
struct DequantPack8
{
template <typename Y, typename X, typename Z>
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
__host__ __device__ constexpr void
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
{
#if 1
int x_permute = 0;
int bits4_0 = (bit_cast<int>(x) >> 0) & 0xF;
int bits4_1 = (bit_cast<int>(x) >> 4) & 0xF;
int bits4_2 = (bit_cast<int>(x) >> 8) & 0xF;
int bits4_3 = (bit_cast<int>(x) >> 12) & 0xF;
int bits4_4 = (bit_cast<int>(x) >> 16) & 0xF;
int bits4_5 = (bit_cast<int>(x) >> 20) & 0xF;
int bits4_6 = (bit_cast<int>(x) >> 24) & 0xF;
int bits4_7 = (bit_cast<int>(x) >> 28) & 0xF;
x_permute |= (bits4_1 << 0);
x_permute |= (bits4_3 << 4);
x_permute |= (bits4_5 << 8);
x_permute |= (bits4_7 << 12);
x_permute |= (bits4_0 << 16);
x_permute |= (bits4_2 << 20);
x_permute |= (bits4_4 << 24);
x_permute |= (bits4_6 << 28);
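// The swizzle above gathers the odd-indexed nibbles (1, 3, 5, 7) into the low
// 16 bits of x_permute and the even-indexed ones (0, 2, 4, 6) into the high
// 16 bits; pki4_to_half4_scale then pulls element pairs out of x_permute and
// x_permute >> 8 with its LO/HI masks, presumably so the eight halves come out
// in the element order implied by the pk_i4x4_t packing.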
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(x_permute, z);
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4_scale(x_permute >> 8, z);
y = result.template AsType<half8_t>()[Number<0>{}];
#elif 1
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<1>{}) =
pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
y = result.template AsType<half8_t>()[Number<0>{}]; y = result.template AsType<half8_t>()[Number<0>{}];
#else #else
vector_type<half_t, 8> dst; vector_type<half_t, 8> dst;
...
@@ -1914,7 +1914,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0)); make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));
constexpr auto b_scale_thread_slice_copy_step = constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1)); make_tuple(make_multi_index(NWaves * NPerXdl, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, 1));
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
...
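The gridwise change above replaces the two-element B-scale copy step with a three-element one: a forward step per NRepeat slice, a plain rewind, and a rewind that also advances one scale block along K. The pipelines pick step 2 when the loop index crosses a num_loop_per_scale boundary and step 1 otherwise. A rough host-side sketch of the resulting walk (illustrative sizes, assuming NPerBlock == NRepeat * NWaves * NPerXdl; the real pipelines also offset the loop index by their prefetch depth, e.g. i + 2 or i + 4):

#include <cstdio>

int main()
{
    const int NWaves = 4, NPerXdl = 32, NRepeat = 2;
    const int NPerBlock = NRepeat * NWaves * NPerXdl;
    const int num_loop = 8, num_loop_per_scale = 4;

    int n = 0, k_scale = 0; // window origin in the B-scale grid
    for(int i = 0; i < num_loop; ++i)
    {
        for(int n0 = 0; n0 < NRepeat; ++n0)
        {
            std::printf("loop %d, n0 %d: scale window at (n=%d, k=%d)\n", i, n0, n, k_scale);
            n += NWaves * NPerXdl; // step.At(0): advance to the next NRepeat slice
        }
        if((i + 1) % num_loop_per_scale == 0)
        {
            n -= NPerBlock; // step.At(2): rewind N ...
            k_scale += 1;   // ... and move to the next scale block along K
        }
        else
        {
            n -= NPerBlock; // step.At(1): rewind N, stay in the same scale block
        }
    }
    return 0;
}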
@@ -1252,6 +1252,237 @@ struct ThreadwiseTensorSliceTransfer_v4
}); });
} }
// Fused-scale variant of Run: additionally takes a dequantization scale that is
// applied while converting the source data (the pk_i4 -> half path below routes
// it through DequantPack8).
template <typename SrcRefToOriginDisplacement,
typename DstOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcBuffer& src_buf,
const DstData& scale,
const DstDesc&,
const DstOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
// SrcOriginToRefDistance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
// scalar per access of each dim
constexpr auto src_scalar_per_access = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<SrcScalarPerVector>{};
}
else
{
return Number<1>{};
}
},
Number<nDim>{});
// scalar step (if stepping on SrcVectorDim) of each dim
constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<1>{};
}
else
{
return Number<0>{};
}
},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
// TODO: unable to compile
// position in slice window
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
#else
// position in slice window
constexpr auto data_to_origin_disp_idx =
ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
// src coordinate
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
if constexpr(SrcBuffer::IsDynamicBuffer())
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
is_src_valid);
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
static_assert(false, "");
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
i * src_scalar_step_in_vector);
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
});
}
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
vector_type<DstData, 2> scale_vector;
scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
scale_vector.template AsType<DstData>()(Number<1>{}) = scale;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::DequantPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i],
scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, f8_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0)
{
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 2;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack2{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else
{
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
});
}
template <typename SrcSliceMoveStepIdx> template <typename SrcSliceMoveStepIdx>
__device__ void MoveSrcSliceWindow(const SrcDesc&, __device__ void MoveSrcSliceWindow(const SrcDesc&,
const SrcSliceMoveStepIdx& src_slice_move_step_idx) const SrcSliceMoveStepIdx& src_slice_move_step_idx)
...