Commit 1377186f authored by wangshaojie6's avatar wangshaojie6
Browse files

4 stage prefetch

parent 64fbf5a2
...@@ -567,16 +567,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -567,16 +567,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
// preload data to regiester and LDS // preload data to regiester and LDS
{ {
// Read // Read
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_0);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_0); b_thread_buf_0);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_0);
// Move // Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
...@@ -585,16 +586,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -585,16 +586,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
b_thread_slice_copy_step); b_thread_slice_copy_step);
// Read // Read
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_1);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_1); b_thread_buf_1);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_1);
// Move // Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
...@@ -603,16 +604,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -603,16 +604,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
b_thread_slice_copy_step); b_thread_slice_copy_step);
// Read // Read
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_2);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_2); b_thread_buf_2);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_2);
// Move // Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
...@@ -620,6 +622,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -620,6 +622,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
// a data write to lds // a data write to lds
...@@ -632,17 +648,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -632,17 +648,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
index_t i = 0; index_t i = 0;
do do
{ {
static_for<0, MultiK0, 4>{}([&](auto) { {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf); blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
...@@ -703,7 +709,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -703,7 +709,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
b_thread_slice_copy_step); b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step); a_thread_slice_copy_step);
});
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
}
i += 1; i += 1;
} while(i < (K0BlockMainLoop - 1)); } while(i < (K0BlockMainLoop - 1));
...@@ -713,16 +730,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -713,16 +730,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
{ {
static_for<0, MultiK0, 4>{}([&](auto i) { static_for<0, MultiK0, 4>{}([&](auto i) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf); blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf);
...@@ -786,6 +794,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1 ...@@ -786,6 +794,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
b_thread_slice_copy_step); b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step); a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment