Commit 428ae72a authored by ltqin's avatar ltqin
Browse files

using MultiK0 control b load data loop

parent 993ec45c
...@@ -112,7 +112,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -112,7 +112,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto MultiK0 = 2; static constexpr auto MultiK0 = 2 * 1;
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
...@@ -543,40 +543,46 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -543,40 +543,46 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// main body // main body
if constexpr(HasMainK0BlockLoop) if constexpr(HasMainK0BlockLoop)
{ {
index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / (2 * K0PerBlock)); index_t K0BlockMainLoop =
index_t i = 0; __builtin_amdgcn_readfirstlane(K0 / (MultiK0 * K0PerBlock));
index_t i = 0;
do do
{ {
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf);
blockwise_gemm.ResetABlockStartWindow(); blockwise_gemm.ResetABlockStartWindow();
block_sync_lds(); block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
// only move b windows static_for<0, MultiK0, 2>{}([&](auto) {
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, // only move b windows
b_grid_buf, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_slice_copy_step);
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf);
blockwise_gemm.MoveABlockSliceWindow(); b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.MoveABlockSliceWindow();
});
block_sync_lds(); block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// move a and b window // move a and b window
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step); a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
i += 1; i += 1;
} while(i < (K0BlockMainLoop - 1)); } while(i < (K0BlockMainLoop - 1));
...@@ -586,18 +592,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -586,18 +592,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
{ {
block_sync_lds(); block_sync_lds();
// block_sync_lds();
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf);
blockwise_gemm.ResetABlockStartWindow(); blockwise_gemm.ResetABlockStartWindow();
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); static_for<0, MultiK0, 2>{}([&](auto i) {
// block_sync_lds();
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
// block_sync_lds(); if constexpr(i < MultiK0 - 2)
blockwise_gemm.MoveABlockSliceWindow(); { // only move b windows
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf);
}
// block_sync_lds();
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
if constexpr(i < MultiK0 - 2)
{
blockwise_gemm.MoveABlockSliceWindow();
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}
});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment