"...composable_kernel.git" did not exist on "6ef034f6cad7ec70b3a06518bec7fef8def11d51"
Commit f98de64a authored by ltqin's avatar ltqin
Browse files

regular code

parent 6e3b47c3
...@@ -442,7 +442,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -442,7 +442,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(), a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
true>{}; true>{};
}, },
Number<4>{}); Number<BaseMultK0>{});
const auto wave_k_m_id = GetWaveKMIdx(wave_id[I2]); const auto wave_k_m_id = GetWaveKMIdx(wave_id[I2]);
auto a_threadwise_copy = auto a_threadwise_copy =
...@@ -488,7 +488,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -488,7 +488,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(), b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{}; true>{};
}, },
Number<4>{}); Number<BaseMultK0>{});
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]); const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
...@@ -563,71 +563,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -563,71 +563,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
{ {
// Read // Read
auto read_first_half_data = [&]() { auto read_first_half_data = [&]() {
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, static_for<0, BaseMultK0 / 2, 1>{}([&](auto i) {
a_grid_buf, a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_buf(Number<0>{})); make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_buf(Number<i>{}));
a_thread_slice_copy_step); a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_buf(Number<0>{})); make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_buf(Number<i>{}));
b_thread_slice_copy_step); b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, });
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf(Number<1>{}));
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}; };
auto read_last_half_data = [&]() { auto read_last_half_data = [&]() {
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, static_for<BaseMultK0 / 2, BaseMultK0, 1>{}([&](auto i) {
a_grid_buf, a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_grid_buf,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_buf(Number<2>{})); make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_buf(Number<i>{}));
a_thread_slice_copy_step); a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_buf(Number<2>{})); make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_buf(Number<i>{}));
b_thread_slice_copy_step); b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, });
a_grid_buf, };
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3, auto run_first_half_gemm = [&]() {
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), static_for<0, BaseMultK0 / 2, 1>{}([&](auto i) {
a_thread_buf(Number<3>{})); blockwise_gemm.Run(
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3, a_thread_buf(Number<i>{}), b_thread_buf(Number<i>{}), c_thread_buf);
a_thread_slice_copy_step); });
};
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, auto run_last_half_gemm = [&]() {
b_grid_buf, static_for<BaseMultK0 / 2, BaseMultK0, 1>{}([&](auto i) {
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, blockwise_gemm.Run(
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), a_thread_buf(Number<i>{}), b_thread_buf(Number<i>{}), c_thread_buf);
b_thread_buf(Number<3>{})); });
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}; };
read_first_half_data(); read_first_half_data();
// Initialize C // Initialize C
...@@ -641,24 +624,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -641,24 +624,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
index_t i = 0; index_t i = 0;
do do
{ {
// 1st
read_last_half_data(); read_last_half_data();
s_nop(); s_nop();
run_first_half_gemm();
blockwise_gemm.Run(
a_thread_buf(Number<0>{}), b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<1>{}), b_thread_buf(Number<1>{}), c_thread_buf);
read_first_half_data(); read_first_half_data();
s_nop(); s_nop();
run_last_half_gemm();
blockwise_gemm.Run(
a_thread_buf(Number<2>{}), b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<3>{}), b_thread_buf(Number<3>{}), c_thread_buf);
i += 1; i += 1;
} while(i < (K0BlockMainLoop - 1)); } while(i < (K0BlockMainLoop - 1));
...@@ -666,18 +637,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -666,18 +637,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// tail // tail
{ {
// 1st
read_last_half_data(); read_last_half_data();
s_nop(); s_nop();
blockwise_gemm.Run( run_first_half_gemm();
a_thread_buf(Number<0>{}), b_thread_buf(Number<0>{}), c_thread_buf); run_last_half_gemm();
blockwise_gemm.Run(
a_thread_buf(Number<1>{}), b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<2>{}), b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<3>{}), b_thread_buf(Number<3>{}), c_thread_buf);
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment