Commit e8a71150 authored by ltqin's avatar ltqin
Browse files

change to 8 buffer

parent 83e6a4b9
...@@ -112,8 +112,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -112,8 +112,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto BaseMultK0 = 4; static constexpr auto BaseMultK0 = 8;
static constexpr auto MultiK0 = BaseMultK0 * 2; static constexpr auto MultiK0 = BaseMultK0 * 1;
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
...@@ -449,7 +449,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -449,7 +449,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(), b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{}; true>{};
}, },
Number<4>{}); Number<8>{});
const auto wave_id = GetWaveIdx(); const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]); const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
...@@ -529,31 +529,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -529,31 +529,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
{ {
// Read // Read
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{})); b_thread_buf(Number<0>{}));
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
// Initialize C
c_thread_buf.Clear();
// a data write to lds
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// load 2nd a matrix data
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{})); b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
// Initialize C
c_thread_buf.Clear();
// a data write to lds
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// main body // main body
if constexpr(HasMainK0BlockLoop) if constexpr(HasMainK0BlockLoop)
{ {
...@@ -567,31 +580,49 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -567,31 +580,49 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
block_sync_lds(); block_sync_lds();
static_for<0, MultiK0, BaseMultK0>{}([&](auto) { static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
// 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{})); b_thread_buf(Number<4>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{})); b_thread_buf(Number<5>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<6>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<7>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
s_nop(); s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
// 2nd
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
// 3rd blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
...@@ -606,13 +637,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -606,13 +637,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_buf(Number<1>{})); b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
s_nop(); s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<4>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
// 4th blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<5>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<6>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<7>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
}); });
...@@ -632,34 +683,51 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -632,34 +683,51 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
blockwise_gemm.ResetABlockStartWindow(); blockwise_gemm.ResetABlockStartWindow();
static_for<0, MultiK0, BaseMultK0>{}([&](auto i) { static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
// 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{})); b_thread_buf(Number<4>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<5>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 2nd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{})); b_thread_buf(Number<6>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<7>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
// 3rd blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
if constexpr(i < MultiK0 - BaseMultK0) if constexpr(i < MultiK0 - BaseMultK0)
{ {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
...@@ -667,14 +735,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -667,14 +735,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_buf(Number<0>{})); b_thread_buf(Number<0>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
}
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 4th
if constexpr(i < MultiK0 - BaseMultK0)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
...@@ -682,9 +742,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -682,9 +742,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_buf(Number<1>{})); b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
} }
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf); s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<4>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<5>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<6>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<7>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment