Commit 2b840f5a authored by aska-0096's avatar aska-0096
Browse files

reduce prefetch stage in blockwisepipev4

parent 925c0719
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace ck { namespace ck {
// Compute optimal pipeline with highest resource request // Compute optimal pipeline with highest resource request
// GlobalPrefetchStages: 4 // GlobalPrefetchStages: 3
// LocalPreFillStages: 2 // LocalPreFillStages: 2
// LocalPreFetchStages: 1 // LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2 // LocalSharedMemoryBuffer: 2
...@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
static constexpr index_t PrefetchStages = 4; static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 2; static constexpr index_t PrefillStages = 2;
static constexpr index_t GlobalBufferNum = 2; static constexpr index_t GlobalBufferNum = 1;
static constexpr index_t HotloopUnroll = 2; static constexpr index_t HotloopUnroll = 2;
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
...@@ -164,8 +164,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -164,8 +164,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
} }
} }
template <typename ScheduleGroup> __device__ static constexpr void HotLoopScheduler()
__device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group)
{ {
// TODO: Take data type into consideration as pipe ver 3 // TODO: Take data type into consideration as pipe ver 3
// A-B split schedule // A-B split schedule
...@@ -195,42 +194,42 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -195,42 +194,42 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
ignore = i; ignore = i;
static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
ignore = idsread; ignore = idsread;
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite; ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, __builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_dsread_per_issue_a - num_mfma_per_issue - num_dsread_per_issue_a -
num_dswrite_per_issue_a, num_dswrite_per_issue_a,
schedule_group); // MFMA 0); // MFMA
}); });
static_for<0, num_issue_b, 1>{}([&](auto i) { static_for<0, num_issue_b, 1>{}([&](auto i) {
ignore = i; ignore = i;
static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
ignore = idsread; ignore = idsread;
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite; ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, __builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_dsread_per_issue_a - num_mfma_per_issue - num_dsread_per_issue_a -
num_dswrite_per_issue_b, num_dswrite_per_issue_b,
schedule_group); // MFMA 0); // MFMA
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
...@@ -274,26 +273,15 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -274,26 +273,15 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs; StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
// Global prefetch 1 // Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefill 1 // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
// Local prefill 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1);
// Local prefetch 1 // Local prefetch 1
block_sync_lds(); block_sync_lds();
...@@ -316,16 +304,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -316,16 +304,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}); });
}); });
// Global prefetch 3 // Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Global prefetch 4 // Local prefill 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
// Global prefetch 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
...@@ -343,9 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -343,9 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
auto LoopFunc = [&](auto lds_read_buf, auto LoopFunc = [&](auto lds_read_buf,
auto lds_read_reg_buf, auto lds_read_reg_buf,
auto lds_write_buf, auto lds_write_buf,
auto vmem_buf, auto mfma_reg_buf) {
auto mfma_reg_buf,
auto schedule_group) {
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
...@@ -368,13 +358,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -368,13 +358,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}); });
}); });
a_blockwise_copy.RunWrite( a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
b_blockwise_copy.RunWrite(
b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
...@@ -411,11 +399,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -411,11 +399,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}); });
}); });
HotLoopScheduler(schedule_group); HotLoopScheduler();
}; };
LoopFunc(I1, I1, I0, I0, I0, I0); LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1, I1, I0); LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll; i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages)); } while(i < (num_loop - PrefetchStages));
...@@ -424,9 +412,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -424,9 +412,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
auto ReadWriteCompFunc = [&](auto lds_read_buf, auto ReadWriteCompFunc = [&](auto lds_read_buf,
auto lds_read_reg_buf, auto lds_read_reg_buf,
auto lds_write_buf, auto lds_write_buf,
auto vmem_buf, auto mfma_reg_buf) {
auto mfma_reg_buf,
auto schedule_group) {
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
...@@ -448,8 +434,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -448,8 +434,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}); });
}); });
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -479,13 +465,10 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -479,13 +465,10 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}); });
}); });
HotLoopScheduler(schedule_group); HotLoopScheduler();
}; };
auto ReadCompFunc = [&](auto lds_read_buf, auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
auto lds_read_reg_buf,
auto mfma_reg_buf,
auto schedule_group) {
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { static_for<0, KRepeat, 1>{}([&](auto k) {
...@@ -535,7 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -535,7 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}); });
}); });
HotLoopScheduler(schedule_group); HotLoopScheduler();
}; };
auto CompFunc = [&](auto mfma_reg_buf) { auto CompFunc = [&](auto mfma_reg_buf) {
...@@ -570,15 +553,13 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -570,15 +553,13 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
// tail // tail
if constexpr(TailNum == TailNumber::Odd) if constexpr(TailNum == TailNumber::Odd)
{ {
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); ReadWriteCompFunc(I1, I1, I0, I0);
ReadCompFunc(I0, I0, I1, I1); ReadCompFunc(I0, I0, I1);
CompFunc(I0); CompFunc(I0);
} }
else if constexpr(TailNum == TailNumber::Even) else if constexpr(TailNum == TailNumber::Even)
{ {
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); ReadCompFunc(I1, I1, I0);
ReadWriteCompFunc(I0, I0, I1, I1, I1, I1);
ReadCompFunc(I1, I1, I0, I1);
CompFunc(I1); CompFunc(I1);
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment