Commit a16ba4d9 authored by aska-0096

Performance regression

parent 00825b10
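This commit restructures `BlockwiseGemmXdlops_pipeline_v1` around a double global prefetch: the first prefetch is written to LDS immediately, a second is held in VGPRs, the steady-state loop exits one k-loop earlier (`num_loop - 2`), and an explicit two-k-loop tail drains the pipeline. The diff leans heavily on AMDGPU scheduling intrinsics; the sketch below (a minimal HIP device-side sketch, assuming the mask values documented for the LLVM AMDGPU backend, which agree with the comments in this diff) summarizes their semantics.

```cpp
#include <hip/hip_runtime.h>

// Minimal sketch of the scheduling hints used throughout this diff.
// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the
// scheduler to place `size` instructions of the class selected by `mask`
// next, within ordering group `sync_id`. Mask bits seen in this commit:
//   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS (LDS) read, 0x200 = DS write.
__device__ void schedule_hint_example()
{
    __builtin_amdgcn_sched_group_barrier(0x100, 6, 0);  // next: 6 LDS reads
    __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // then: 16 MFMAs
    __builtin_amdgcn_sched_barrier(0); // mask 0: nothing may move across this point
}
```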
@@ -66,7 +66,8 @@ struct BlockwiseGemmXdlops_pipeline_v1
     static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
 
     static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
-    static constexpr auto xdlops_gemm_sp = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack/2, TransposeC>{};
+    static constexpr auto xdlops_gemm_sp =
+        XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack / 2, TransposeC>{};
 
     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
     static constexpr index_t KRepeat = KPerThread / KPack;
@@ -339,7 +340,7 @@ struct BlockwiseGemmXdlops_pipeline_v1
             a_thread_desc_.GetElementSpaceSize());
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             b_thread_desc_.GetElementSpaceSize());
 
-        // preload data into LDS
+        // Global prefetch 1st
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -347,9 +348,19 @@ struct BlockwiseGemmXdlops_pipeline_v1
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
         // Initialize C
         c_thread_buf.Clear();
 
+        // Write to LDS immediately to save VGPR
         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+        __builtin_amdgcn_sched_barrier(0);
+
+        // Global prefetch 2nd, hold in VGPR
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
         // Wait all wave produce this K-loop data
         block_sync_lds();
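The prologue above now keeps two prefetches in flight. Below is a runnable host-side model of the pipeline this commit builds (an analogy only, not CK code; `read`/`write`/`compute` stand for `RunRead`, `RunWrite`, and the MFMA stage):

```cpp
#include <cstdio>

// Host-side model of the double-prefetch pipeline: k-loop i is computed
// from LDS while prefetch i+1 drains from VGPR to LDS and prefetch i+2 is
// issued from global memory.
int main()
{
    const int num_loop = 5;
    std::printf("read 0 (global->VGPR); write 0 (VGPR->LDS); read 1 (global->VGPR)\n");
    int i = 0;
    do
    {
        std::printf("compute %d; write %d (VGPR->LDS); read %d (global->VGPR)\n",
                    i, i + 1, i + 2);
        ++i;
    } while(i < num_loop - 2); // exits one k-loop earlier than before
    // tail: no global reads remain; drain the last two k-loops
    std::printf("compute %d; write %d (VGPR->LDS); compute %d\n",
                num_loop - 2, num_loop - 1, num_loop - 1);
    return 0;
}
```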
@@ -386,9 +397,6 @@ struct BlockwiseGemmXdlops_pipeline_v1
         do
         {
-            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
             // Here only KRepeat-1 times read (1~KRepeat) & compute (0~KRepeat-1) of this k-loop
             static_for<1, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
                 /* Read N+1 */
@@ -410,7 +418,6 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                        b_thread_desc_,
                                        make_tuple(n0, Number<k % 2>{}, I0, I0),
                                        b_thread_buf);
                 });
-
                 /* Compute N */
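The `Number<k % 2>{}` destination index above is a two-slot ping-pong in the per-thread buffers: slot `k % 2` is filled for iteration `k` while the MFMA consumes slot `(k - 1) % 2`. A runnable model of just the indexing (a simplification; the real slots live in `a_thread_buf`/`b_thread_buf`):

```cpp
#include <cstdio>

// Model of the k%2 ping-pong used by a_thread_buf/b_thread_buf: fill one
// slot while computing from the other, so LDS reads overlap MFMA issue.
int main()
{
    constexpr int KRepeat = 3;
    int slot[2] = {-1, -1};
    slot[0] = 0; // prologue: the LDS read of k=0 lands in slot 0
    for(int k = 1; k < KRepeat; ++k)
    {
        slot[k % 2] = k; // "Read N+1" into the free slot
        std::printf("compute k=%d from slot %d\n", k - 1, (k - 1) % 2);
    }
    // the trailing compute uses slot (KRepeat - 1) % 2, matching the
    // (KRepeat - 1) % 2 offsets in the hunks below
    std::printf("compute k=%d from slot %d\n", KRepeat - 1, (KRepeat - 1) % 2);
    return 0;
}
```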
@@ -450,20 +457,26 @@ struct BlockwiseGemmXdlops_pipeline_v1
             __builtin_amdgcn_sched_group_barrier(0x100, 3, 0); // DS read
             __builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
             */
-            __builtin_amdgcn_sched_group_barrier(0x020, 6, 0); // VMEM read
             __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
             __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
             __builtin_amdgcn_sched_barrier(0);
 
             // Wait all wave consume this k-loop data
             block_sync_lds();
 
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
-            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+            // Write 2nd-prefetch data of this k-loop to LDS
             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+            /* We have the entire loop to cover buffer_load latency */
+            // Next global prefetch
+            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
             ++i;
             // compute(idx=KRepeat) of this k-loop can hide ds_write latency
             static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -472,7 +485,7 @@ struct BlockwiseGemmXdlops_pipeline_v1
                     vector_type<FloatAB, KPack> a_thread_vec;
                     vector_type<FloatAB, KPack> b_thread_vec;
-                    static_for<0, KPack/2, 1>{}([&](auto ik) {
+                    static_for<0, KPack / 2, 1>{}([&](auto ik) {
                         a_thread_vec.template AsType<FloatAB>()(ik) =
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                 make_tuple(m0, (KRepeat - 1) % 2, 0, ik))>{}];
@@ -493,11 +506,12 @@ struct BlockwiseGemmXdlops_pipeline_v1
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
-            // Wait all wave produce next k-loop data
-            __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
+            // __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
             __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write
-            __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x020, 6, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
             __builtin_amdgcn_sched_barrier(0);
+            // Wait all wave produce next k-loop data
             block_sync_lds();
             // Here 1 time prefetch read(idx=0) of next K-loop
             static_for<0, MRepeat, 1>{}([&](auto m0) {
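The regrouped hints above express the new steady-state overlap: after the compute that hides `ds_write` latency, six LDS writes (the held prefetch), six buffer loads (the next prefetch), and eight MFMAs are interleaved. In isolation (a device-side sketch mirroring the hunk, not new API; same includes as the first sketch):

```cpp
// Device-side sketch mirroring the regrouped hints above: issue the LDS
// writes of the held prefetch and the buffer loads of the next prefetch,
// then let 8 MFMAs cover both latencies before the hard barrier.
__device__ void ds_write_phase_schedule()
{
    __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // 6 DS writes
    __builtin_amdgcn_sched_group_barrier(0x020, 6, 0); // 6 VMEM (buffer) reads
    __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // 8 MFMAs hide both
    __builtin_amdgcn_sched_barrier(0);
}
```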
@@ -508,7 +522,6 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                    a_thread_desc_,
                                    make_tuple(m0, I0, I0, I0),
                                    a_thread_buf);
             });
-
             static_for<0, NRepeat, 1>{}([&](auto n0) {
                 // read B
@@ -525,13 +538,13 @@ struct BlockwiseGemmXdlops_pipeline_v1
                     vector_type<FloatAB, KPack> a_thread_vec;
                     vector_type<FloatAB, KPack> b_thread_vec;
-                    static_for<0, KPack/2, 1>{}([&](auto ik) {
+                    static_for<0, KPack / 2, 1>{}([&](auto ik) {
                         a_thread_vec.template AsType<FloatAB>()(ik) =
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                make_tuple(m0, (KRepeat - 1) % 2, 0, ik+KPack/2))>{}];
+                                make_tuple(m0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
                         b_thread_vec.template AsType<FloatAB>()(ik) =
                             b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                make_tuple(n0, (KRepeat - 1) % 2, 0, ik+KPack/2))>{}];
+                                make_tuple(n0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
                     });
                     using mfma_input_type =
@@ -556,11 +569,12 @@ struct BlockwiseGemmXdlops_pipeline_v1
             // __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
             __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
             __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
-        } while(i < (num_loop - 1));
+        } while(i < (num_loop - 2));
         }
 
         // tail
         {
+            // (N-2)-th k-loop
             // Here only KRepeat-1 times read & compute
             static_for<1, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
                 /* Read N+1 */
@@ -581,7 +595,6 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                        b_thread_desc_,
                                        make_tuple(n0, Number<k % 2>{}, I0, I0),
                                        b_thread_buf);
                 });
-
                 /* Compute N */
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
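Because the main loop now stops at `num_loop - 2`, the tail added in the next hunk drains two k-loops. For k-loop N-2, the final K-slice is issued as two `KPack / 2`-wide `xdlops_gemm_sp` calls (lower half `ik`, upper half `ik + KPack / 2`) that accumulate into the same C registers, leaving room to schedule the last LDS writes and the sync between the halves. A runnable scalar analogue of why the split is value-preserving:

```cpp
#include <cstdio>

// Scalar analogue (not a real MFMA) of the split-KPack tail: two KPack/2
// partial products accumulated into the same C equal one KPack-wide issue,
// which is why xdlops_gemm_sp (KPack/2) can stand in for xdlops_gemm here.
int main()
{
    constexpr int KPack = 8;
    float a[KPack], b[KPack];
    for(int ik = 0; ik < KPack; ++ik)
    {
        a[ik] = 1.0f + ik;
        b[ik] = 2.0f * ik;
    }

    float c_full = 0.0f;
    for(int ik = 0; ik < KPack; ++ik) c_full += a[ik] * b[ik]; // one full issue

    float c_split = 0.0f; // two half issues; other work can go between them
    for(int ik = 0; ik < KPack / 2; ++ik) c_split += a[ik] * b[ik];
    for(int ik = 0; ik < KPack / 2; ++ik) c_split += a[ik + KPack / 2] * b[ik + KPack / 2];

    std::printf("full=%g split=%g\n", c_full, c_split); // identical results
    return 0;
}
```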
@@ -639,6 +652,151 @@ struct BlockwiseGemmXdlops_pipeline_v1
                 });
             });
+            __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
+            __builtin_amdgcn_sched_barrier(0);
+
+            block_sync_lds();
+            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+            // compute(idx=KRepeat) of this k-loop can hide ds_write latency
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                static_for<0, NRepeat, 1>{}([&](auto n0) {
+                    /* Compute N */
+                    vector_type<FloatAB, KPack> a_thread_vec;
+                    vector_type<FloatAB, KPack> b_thread_vec;
+
+                    static_for<0, KPack / 2, 1>{}([&](auto ik) {
+                        a_thread_vec.template AsType<FloatAB>()(ik) =
+                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                make_tuple(m0, (KRepeat - 1) % 2, 0, ik))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(ik) =
+                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                make_tuple(n0, (KRepeat - 1) % 2, 0, ik))>{}];
+                    });
+
+                    using mfma_input_type =
+                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                    constexpr index_t c_offset =
+                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                    xdlops_gemm_sp.template Run(
+                        a_thread_vec.template AsType<mfma_input_type>(),
+                        b_thread_vec.template AsType<mfma_input_type>(),
+                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                });
+            });
+            __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write
+            __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
+            __builtin_amdgcn_sched_barrier(0);
+
+            // wait until the final data producer has finished
+            block_sync_lds();
+
+            // (N-1)-th k-loop
+            // Here 1 time prefetch read(idx=0) of next K-loop
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                // read A
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(m0, I0, I0, I0),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(m0, I0, I0, I0),
+                                   a_thread_buf);
+            });
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_thread_buf);
+            });
+
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                static_for<0, NRepeat, 1>{}([&](auto n0) {
+                    /* Compute N */
+                    vector_type<FloatAB, KPack> a_thread_vec;
+                    vector_type<FloatAB, KPack> b_thread_vec;
+
+                    static_for<0, KPack / 2, 1>{}([&](auto ik) {
+                        a_thread_vec.template AsType<FloatAB>()(ik) =
+                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                make_tuple(m0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(ik) =
+                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                make_tuple(n0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
+                    });
+
+                    using mfma_input_type =
+                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                    constexpr index_t c_offset =
+                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                    xdlops_gemm_sp.template Run(
+                        a_thread_vec.template AsType<mfma_input_type>(),
+                        b_thread_vec.template AsType<mfma_input_type>(),
+                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                });
+            });
+            __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
+            __builtin_amdgcn_sched_barrier(0);
+
+            // Here only KRepeat-1 times read & compute
+            static_for<1, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
+                /* Read N+1 */
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    // read A
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(m0, Number<k % 2>{}, I0, I0),
+                                       a_thread_buf);
+                });
+                static_for<0, NRepeat, 1>{}([&](auto n0) {
+                    // read B
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(n0, Number<k % 2>{}, I0, I0),
+                                       b_thread_buf);
+                });
+
+                /* Compute N */
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
+                        vector_type<FloatAB, KPack> a_thread_vec;
+                        vector_type<FloatAB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<FloatAB>()(ik) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(m0, (k - 1) % 2, 0, ik))>{}];
+                            b_thread_vec.template AsType<FloatAB>()(ik) =
+                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                    make_tuple(n0, (k - 1) % 2, 0, ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.template Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
+                });
+            });
+            __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
+            __builtin_amdgcn_sched_barrier(0);
             /* Final Compute issue */
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -654,32 +812,6 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                 make_tuple(n0, (KRepeat - 1) % 2, 0, i))>{}];
                     });
-#if 0
-                    if(get_thread_local_1d_id() == 0)
-                    {
-                        printf(
-                            "rep of m.n.k (%01d, %01d, %01d)\n", m0.value, n0.value, KRepeat - 1);
-                    }
-                    printf("Tid: %03d, A_compute_buf: %04x %04x %04x %04x %04x %04x %04x %04x\n",
-                           get_thread_local_1d_id(),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<0>{})))),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<1>{})))),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<2>{})))),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<3>{})))),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<4>{})))),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<5>{})))),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<6>{})))),
-                           *(reinterpret_cast<uint16_t*>(
-                               &(a_thread_vec.template AsType<FloatAB>()(Number<7>{})))));
-#endif
                     using mfma_input_type =
                         typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
...