"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "2387ce01c91e5ddcb91d86a335d993e0a664b9dd"
Commit b2634103 authored by ltqin's avatar ltqin
Browse files

simple change buffer number

parent b1dd76f3
...@@ -521,7 +521,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -521,7 +521,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
ignore = i; ignore = i;
return blockwise_gemm.AlloCAThreadBuff(); return blockwise_gemm.AlloCAThreadBuff();
}, },
Number<BaseMultK0/2>{}); Number<BaseMultK0 / 2>{});
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...@@ -539,7 +539,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -539,7 +539,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
auto read_b_first_half_data = [&]() { auto read_b_first_half_data = [&]() {
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) { static_for<0, BaseMultK0 / 2, 1>{}([&](auto ii) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
...@@ -550,7 +550,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -550,7 +550,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
}); });
}; };
auto read_b_last_half_data = [&]() { auto read_b_last_half_data = [&]() {
static_for<MultiK0 / 2, MultiK0, 1>{}([&](auto ii) { static_for<BaseMultK0 / 2, BaseMultK0, 1>{}([&](auto ii) {
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
...@@ -561,7 +561,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -561,7 +561,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
}); });
}; };
auto read_a_lds_data = [&]() { auto read_a_lds_data = [&]() {
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) { static_for<0, BaseMultK0 / 2, 1>{}([&](auto ii) {
blockwise_gemm.ReadAThreadData(a_block_buf, a_thread_buf(Number<ii>{})); blockwise_gemm.ReadAThreadData(a_block_buf, a_thread_buf(Number<ii>{}));
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
}); });
...@@ -589,7 +589,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -589,7 +589,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
s_barrier(); s_barrier();
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) { static_for<0, BaseMultK0 / 2, 1>{}([&](auto ii) {
blockwise_gemm.Run(a_thread_buf(Number<ii>{}), blockwise_gemm.Run(a_thread_buf(Number<ii>{}),
b_thread_buf(Number<ii>{}), b_thread_buf(Number<ii>{}),
c_thread_buf); c_thread_buf);
...@@ -600,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -600,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
s_barrier(); s_barrier();
static_for<MultiK0 / 2, MultiK0, 1>{}([&](auto ii) { static_for<BaseMultK0 / 2, BaseMultK0, 1>{}([&](auto ii) {
blockwise_gemm.Run(a_thread_buf(Number<ii - 4>{}), blockwise_gemm.Run(a_thread_buf(Number<ii - BaseMultK0 / 2>{}),
b_thread_buf(Number<ii>{}), b_thread_buf(Number<ii>{}),
c_thread_buf); c_thread_buf);
}); });
...@@ -628,7 +628,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -628,7 +628,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
s_barrier(); s_barrier();
static_for<0, MultiK0 / 2, 1>{}([&](auto ii) { static_for<0, BaseMultK0 / 2, 1>{}([&](auto ii) {
blockwise_gemm.Run( blockwise_gemm.Run(
a_thread_buf(Number<ii>{}), b_thread_buf(Number<ii>{}), c_thread_buf); a_thread_buf(Number<ii>{}), b_thread_buf(Number<ii>{}), c_thread_buf);
}); });
...@@ -642,8 +642,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -642,8 +642,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
s_barrier(); s_barrier();
static_for<MultiK0 / 2, MultiK0, 1>{}([&](auto ii) { static_for<BaseMultK0 / 2, BaseMultK0, 1>{}([&](auto ii) {
blockwise_gemm.Run(a_thread_buf(Number<ii - 4>{}), blockwise_gemm.Run(a_thread_buf(Number<ii - BaseMultK0 / 2>{}),
b_thread_buf(Number<ii>{}), b_thread_buf(Number<ii>{}),
c_thread_buf); c_thread_buf);
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment