Commit a288c57c authored by valarLip's avatar valarLip
Browse files

update

parent cf646183
...@@ -435,26 +435,44 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -435,26 +435,44 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue); gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0) if constexpr(i_issue % mfma_per_gld_g == 0)
{
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{}); gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
move_g();
}
if constexpr(i_issue % mfma_per_gld_a == 0) if constexpr(i_issue % mfma_per_gld_a == 0)
{
gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{}); gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
move_a();
}
if constexpr(i_issue % mfma_per_sld_a == 0) if constexpr(i_issue % mfma_per_sld_a == 0)
{
block_sync_lds();
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{}); sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
}
}); });
// compute buffer 1 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue); gemm_0(acc_0, as[I1], gs[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0) if constexpr(i_issue % mfma_per_gld_g == 0)
{
gld_g(gs[I0], number<i_issue / mfma_per_gld_g>{}); gld_g(gs[I0], number<i_issue / mfma_per_gld_g>{});
move_g();
}
if constexpr(i_issue % mfma_per_gld_a == 0) if constexpr(i_issue % mfma_per_gld_a == 0)
{
gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{}); gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{});
move_a();
}
if constexpr(i_issue % mfma_per_sld_a == 0) if constexpr(i_issue % mfma_per_sld_a == 0)
{
block_sync_lds();
sld_a(as[I0], a_sld_win0, number<i_issue / mfma_per_sld_a>{}); sld_a(as[I0], a_sld_win0, number<i_issue / mfma_per_sld_a>{});
}
}); });
}; };
...@@ -564,15 +582,19 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -564,15 +582,19 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off // clang-format off
gld_a(a_sst_win0, NEG1, TRUE); gld_a(a_sst_win0, NEG1, TRUE);
gld_g(gs[I0], NEG1, TRUE); gld_g(gs[I0], NEG1, TRUE);
move_a();
move_g();
clear_tile(acc_0);
async_load_fence_raw(g_win.get_num_of_access());
sld_a(as[I0], a_sld_win0, NEG1); sld_a(as[I0], a_sld_win0, NEG1);
gld_a(a_sst_win1, NEG1); gld_a(a_sst_win1, NEG1);
clear_tile(acc_0);
// we manually unroll double buffer inside hot loop // we manually unroll double buffer inside hot loop
const index_t iters_0 = (num_blocks_k0 - 2) / 2; const index_t iters_0 = (num_blocks_k0 - 2) / 2;
index_t i_0 = 0; index_t i_0 = 0;
while(i_0 < iters_0) while(i_0++ < iters_0)
{ {
pipeline_gemm0(); pipeline_gemm0();
} }
...@@ -583,7 +605,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -583,7 +605,7 @@ struct FusedMoeGemmPipeline_Flatmm
const index_t iters_1 = (num_blocks_n1 - 2) / 2; const index_t iters_1 = (num_blocks_n1 - 2) / 2;
index_t i_1 = 0; index_t i_1 = 0;
pipeline_gemm1_head(); pipeline_gemm1_head();
while(i_1 < iters_1) while(i_1++ < iters_1)
{ {
pipeline_gemm1(); pipeline_gemm1();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment