Commit 3bb718ad authored by valarLip's avatar valarLip
Browse files

update pipeline_gemm0

parent c6c3c142
......@@ -640,6 +640,11 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Fence for LDS (shared-memory) traffic: stall this wavefront until at most
// `cnt` LDS/constant-load operations remain outstanding (the lgkmcnt
// counter). The default cnt = 0 waits for ALL such operations to retire.
// NOTE(review): the "n" asm constraint requires `cnt` to be a compile-time
// constant after inlining — presumably all callers pass literals; confirm.
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
// s_waitcnt lgkmcnt(N) blocks until <= N LDS-class ops are in flight;
// the "memory" clobber prevents the compiler reordering accesses across it.
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
......
......@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
}
// Combined "wait for global loads, then workgroup barrier" primitive: first
// stalls until at most `cnt` vector-memory (global) loads remain in flight,
// then synchronizes all waves of the workgroup. Used by the software-
// pipelined GEMM loops to publish freshly loaded tiles to the whole block.
// NOTE(review): the "n" constraint needs `cnt` to fold to a compile-time
// constant at the (inlined) call site — TODO confirm for all callers.
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
// gfx12 split-counter ISA: s_wait_loadcnt replaces s_waitcnt vmcnt(), and
// the workgroup barrier is the explicit signal/wait pair on barrier id -1.
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
// Pre-gfx12: wait until <= cnt outstanding vmem loads, then plain s_barrier.
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
......
......@@ -260,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm
{
async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
};
// auto move_a = [&]() {
// move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
// };
auto move_a = [&]() {
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
};
auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
load_tile_raw(a_, win_, i_access);
};
......@@ -284,11 +284,11 @@ struct FusedMoeGemmPipeline_Flatmm
}
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
};
// auto move_g =
// [&]() {
// move_tile_window(g_win,
// {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
// };
auto move_g =
[&]() {
move_tile_window(g_win,
{number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
};
statically_indexed_array<d_thread_type, 2> ds;
auto gld_d = [&]<typename PreNop = bool_constant<false>>(
......@@ -296,10 +296,10 @@ struct FusedMoeGemmPipeline_Flatmm
{
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
};
// auto move_d = [&]() {
// // d move along gemm-n
// move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
// };
auto move_d = [&]() {
// d move along gemm-n
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
};
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
auto& o_, auto i_access, PreNop = {})
......@@ -427,53 +427,66 @@ struct FusedMoeGemmPipeline_Flatmm
// mfma(that can reuse the B matrix) only affected by M repeat.
auto pipeline_gemm0 = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0;
constexpr index_t mfma_per_gld_a = total_loops / issues_a;
constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
constexpr index_t mfma_per_ld = total_loops / (issues_g + issues_a + issues_sld_a);
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0)
if constexpr(i_issue % mfma_per_ld == 0)
{
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
move_g();
}
constexpr index_t ld_id = 0;
if constexpr(i_issue % mfma_per_gld_a == 0)
if constexpr(ld_id < issues_g)
{
gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
move_a();
gld_g(gs[I0], number<ld_id>{});
}
if constexpr(i_issue % mfma_per_sld_a == 0)
if constexpr(ld_id - issues_g < +issues_a)
{
gld_a(a_sst_win0, number<ld_id - issues_g>{});
}
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
{
block_sync_lds();
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
sld_a(as[I1], a_sld_win1, number<ld_id - issues_g - issues_a>{});
}
ld_id++;
}
});
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0)
if constexpr(i_issue % mfma_per_ld == 0)
{
gld_g(gs[I0], number<i_issue / mfma_per_gld_g>{});
move_g();
}
constexpr index_t ld_id = 0;
if constexpr(i_issue % mfma_per_gld_a == 0)
if constexpr(ld_id < issues_g)
{
gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{});
move_a();
gld_g(gs[I1], number<ld_id>{});
}
if constexpr(i_issue % mfma_per_sld_a == 0)
if constexpr(ld_id - issues_g < +issues_a)
{
gld_a(a_sst_win1, number<ld_id - issues_g>{});
}
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
{
block_sync_lds();
sld_a(as[I0], a_sld_win0, number<i_issue / mfma_per_sld_a>{});
sld_a(as[I0], a_sld_win0, number<ld_id - issues_g - issues_a>{});
}
ld_id++;
}
});
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
};
auto pipeline_gemm0_tail = [&]() {
......@@ -486,14 +499,23 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0)
{
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
move_g();
}
// if constexpr (i_issue % mfma_per_gld_a == 0)
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
// if constexpr(i_issue % mfma_per_sld_a == 0)
// {
// block_sync_load_raw(a_sst_win0.get_num_of_access());
// sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
// }
});
// if cycle_mfma>gld_a sync here
block_sync_load_raw(issues_g);
sld_a(as[I1], a_sld_win1, NEG1{});
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
......@@ -523,7 +545,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0)
{
......@@ -536,7 +561,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0)
{
......@@ -553,7 +581,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
move_d();
}
});
};
auto pipeline_gemm1_tail = [&]() {
......@@ -564,7 +595,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0)
{
......@@ -586,10 +620,13 @@ struct FusedMoeGemmPipeline_Flatmm
move_g();
clear_tile(acc_0);
async_load_fence_raw(g_win.get_num_of_access());
sld_a(as[I0], a_sld_win0, NEG1);
// preload for next round
gld_a(a_sst_win1, NEG1);
gld_g(gs[I1], NEG1);
// make sure a,g loaded
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
// we manually unroll double buffer inside hot loop
const index_t iters_0 = (num_blocks_k0 - 2) / 2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment