Unverified commit bfe983a1, authored by Adam Osewski, committed by GitHub.

Change block gemm pipeline local prefill loop order. (#1692)

* Fix loop order.

* Fix loop order in pipeline v4
parent b70f367f
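
Every hunk below makes the same change: within each K-repeat step of the local prefill, the NRepeat loop that issues the B-operand LDS-to-register reads is hoisted out of the MRepeat loop, so it now runs once after all A-operand reads for that k step instead of being re-entered for every m0. A minimal stand-alone sketch of the reordering follows; the plain runtime loops and the read_a/read_b helpers are illustrative stand-ins for the library's compile-time static_for loops and a_thread_copy_/b_thread_copy_ objects, not the pipeline's actual code:

```cpp
#include <cstdio>

// Illustrative repeat counts; in the real pipeline these are template parameters.
constexpr int MRepeat = 2, NRepeat = 2, KRepeat = 2;

// Hypothetical stand-ins for a_thread_copy_.Run / b_thread_copy_.Run.
void read_a(int m0, int k) { std::printf("read A fragment (m0=%d, k=%d)\n", m0, k); }
void read_b(int n0, int k) { std::printf("read B fragment (n0=%d, k=%d)\n", n0, k); }

int main()
{
    for(int k = 0; k < KRepeat; ++k)
    {
        // Old order: the B loop was nested inside the M loop, so the same B
        // fragments were requested once per (m0, k) pair:
        //   for(m0) { read_a(m0, k); for(n0) read_b(n0, k); }

        // New order: finish all A reads for this k step, then issue each B
        // read exactly once.
        for(int m0 = 0; m0 < MRepeat; ++m0)
            read_a(m0, k);
        for(int n0 = 0; n0 < NRepeat; ++n0)
            read_b(n0, k);
    }
    return 0;
}
```
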
@@ -269,15 +269,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(
-                    b_block_desc_n0_n1_n2_k,
-                    make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                    b_block_buf,
-                    b_thread_desc_,
-                    make_tuple(n0, I0, k, I0),
-                    b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_buf);
         });
     });
@@ -341,14 +340,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k, I0),
-                                   b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_buf);
         });
     });
@@ -396,14 +395,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k, I0),
-                                   b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_buf);
         });
     });
@@ -447,14 +446,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k, I0),
-                                   b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_buf);
         });
     });
@@ -760,15 +759,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k0, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(
-                    b_block_desc_n0_n1_n2_k,
-                    make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                    b_block_buf,
-                    b_thread_desc_,
-                    make_tuple(n0, I0, k0, I0),
-                    b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k0, I0),
+                               b_thread_buf);
         });
         __builtin_amdgcn_sched_barrier(0);
         // NOTE: Synchronize threads in a workgroup at the start of each MAC
@@ -866,14 +864,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k0, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k0, I0),
-                                   b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k0, I0),
+                               b_thread_buf);
         });
         __builtin_amdgcn_sched_barrier(0);
@@ -942,14 +940,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k0, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k0, I0),
-                                   b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k0, I0),
+                               b_thread_buf);
         });
         __builtin_amdgcn_sched_barrier(0);
@@ -1018,14 +1016,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k0, I0),
                                a_thread_buf);
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k0, I0),
-                                   b_thread_buf);
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k0, I0),
+                               b_thread_buf);
         });
         __builtin_amdgcn_sched_barrier(0);
@@ -305,14 +305,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_bufs(I0));
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                   b_block_buf.At(I0),
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k, I0),
-                                   b_thread_bufs(I0));
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf.At(I0),
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_bufs(I0));
         });
     });
@@ -356,15 +356,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_bufs(lds_read_reg_buf));
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(
-                    b_block_desc_n0_n1_n2_k,
-                    make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                    b_block_buf.At(lds_read_buf),
-                    b_thread_desc_,
-                    make_tuple(n0, I0, k, I0),
-                    b_thread_bufs(lds_read_reg_buf));
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf.At(lds_read_buf),
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_bufs(lds_read_reg_buf));
         });
     });
@@ -437,14 +436,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_bufs(lds_read_reg_buf));
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                   b_block_buf.At(lds_read_buf),
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k, I0),
-                                   b_thread_bufs(lds_read_reg_buf));
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf.At(lds_read_buf),
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_bufs(lds_read_reg_buf));
         });
     });
@@ -496,14 +495,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                a_thread_desc_,
                                make_tuple(m0, I0, k, I0),
                                a_thread_bufs(lds_read_reg_buf));
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                   b_block_buf.At(lds_read_buf),
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, k, I0),
-                                   b_thread_bufs(lds_read_reg_buf));
-            });
+        });
+        static_for<0, NRepeat, 1>{}([&](auto n0) {
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                               b_block_buf.At(lds_read_buf),
+                               b_thread_desc_,
+                               make_tuple(n0, I0, k, I0),
+                               b_thread_bufs(lds_read_reg_buf));
         });
     });