Unverified Commit bfe983a1 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

Change block gemm pipeline local prefill loop order. (#1692)

* Fix loop order.

* Fix loop order in pipeline v4
parent b70f367f
...@@ -269,9 +269,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -269,9 +269,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
...@@ -279,7 +279,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -279,7 +279,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
b_thread_buf); b_thread_buf);
}); });
}); });
});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -341,6 +340,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -341,6 +340,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
...@@ -350,7 +350,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -350,7 +350,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
b_thread_buf); b_thread_buf);
}); });
}); });
});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -396,6 +395,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -396,6 +395,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
...@@ -405,7 +405,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -405,7 +405,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
b_thread_buf); b_thread_buf);
}); });
}); });
});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -447,6 +446,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -447,6 +446,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
...@@ -456,7 +456,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -456,7 +456,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
b_thread_buf); b_thread_buf);
}); });
}); });
});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -760,16 +759,15 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -760,16 +759,15 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, k0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}), make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
}); });
});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC // NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit // cluster, but except the first, as we can shorten non-MAC cluster a bit
...@@ -866,6 +864,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -866,6 +864,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, k0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}), make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
...@@ -874,7 +873,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -874,7 +873,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
}); });
});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if constexpr(k0.value != 0 || KRepeat == 1) if constexpr(k0.value != 0 || KRepeat == 1)
...@@ -942,6 +940,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -942,6 +940,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, k0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}), make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
...@@ -950,7 +949,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -950,7 +949,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
}); });
});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if constexpr(k0.value != 0 || KRepeat == 1) if constexpr(k0.value != 0 || KRepeat == 1)
...@@ -1018,6 +1016,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -1018,6 +1016,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, k0, I0),
a_thread_buf); a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}), make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
...@@ -1026,7 +1025,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -1026,7 +1025,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
}); });
});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if constexpr(k0.value != 0 || KRepeat == 1) if constexpr(k0.value != 0 || KRepeat == 1)
......
...@@ -305,6 +305,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -305,6 +305,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_bufs(I0)); a_thread_bufs(I0));
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
...@@ -314,7 +315,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -314,7 +315,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
b_thread_bufs(I0)); b_thread_bufs(I0));
}); });
}); });
});
// Global prefetch 3 // Global prefetch 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
...@@ -356,9 +356,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -356,9 +356,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf), b_block_buf.At(lds_read_buf),
b_thread_desc_, b_thread_desc_,
...@@ -366,7 +366,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -366,7 +366,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
b_thread_bufs(lds_read_reg_buf)); b_thread_bufs(lds_read_reg_buf));
}); });
}); });
});
a_blockwise_copy.RunWrite( a_blockwise_copy.RunWrite(
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
...@@ -437,6 +436,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -437,6 +436,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
...@@ -446,7 +446,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -446,7 +446,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
b_thread_bufs(lds_read_reg_buf)); b_thread_bufs(lds_read_reg_buf));
}); });
}); });
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
...@@ -496,6 +495,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -496,6 +495,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k, I0), make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf)); a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
...@@ -505,7 +505,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave, ...@@ -505,7 +505,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
b_thread_bufs(lds_read_reg_buf)); b_thread_bufs(lds_read_reg_buf));
}); });
}); });
});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment