Commit 3bb718ad authored by valarLip's avatar valarLip
Browse files

update pipeline_gemm0

parent c6c3c142
...@@ -640,6 +640,11 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) ...@@ -640,6 +640,11 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
// Fence on LDS (shared-memory) traffic: blocks the wavefront until the number of
// outstanding LGKM operations (LDS/GDS/constant/message ops on GCN-style ISAs)
// has dropped to at most `cnt`.
//
// @param cnt  maximum number of LGKM operations allowed to remain in flight
//             (default 0 = wait for all to complete). Must be a compile-time
//             constant: it is bound via the "n" (immediate) asm constraint.
//
// The "memory" clobber keeps the compiler from reordering memory accesses
// across the fence.
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false> template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if; struct buffer_atomic_add_if;
......
...@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds() ...@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif #endif
} }
// Combined load-fence + workgroup barrier: first waits until at most `cnt`
// vector-memory (global load) operations remain outstanding for this
// wavefront, then synchronizes all waves in the workgroup.
//
// @param cnt  maximum number of vector-memory operations allowed to remain in
//             flight before the barrier (default 0 = drain all). Must be a
//             compile-time constant ("n" immediate asm constraint).
//
// NOTE(review): issuing the waitcnt and the barrier as one asm block keeps the
// compiler from scheduling memory accesses between the two — presumably the
// point of fusing them; confirm against callers that rely on this ordering.
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
    // gfx12 ISA: loadcnt replaces vmcnt, and the single s_barrier is split
    // into a signal/wait pair (-1 = the workgroup's default barrier).
    asm volatile("s_wait_loadcnt %0 \n"
                 "s_barrier_signal -1 \n"
                 "s_barrier_wait -1"
                 :
                 : "n"(cnt)
                 : "memory");
#else
    // Pre-gfx12: wait on the vector-memory counter, then hit the
    // whole-workgroup barrier.
    asm volatile("s_waitcnt vmcnt(%0) \n"
                 "s_barrier"
                 :
                 : "n"(cnt)
                 : "memory");
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load() CK_TILE_DEVICE void block_sync_lds_direct_load()
{ {
asm volatile("\ asm volatile("\
......
...@@ -260,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -260,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm
{ {
async_load_tile_raw(a_store_, a_win, i_access, PreNop{}); async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
}; };
// auto move_a = [&]() { auto move_a = [&]() {
// move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}}); move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
// }; };
auto sld_a = [&](auto& a_, auto& win_, auto i_access) { auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
load_tile_raw(a_, win_, i_access); load_tile_raw(a_, win_, i_access);
}; };
...@@ -284,11 +284,11 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -284,11 +284,11 @@ struct FusedMoeGemmPipeline_Flatmm
} }
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
}; };
// auto move_g = auto move_g =
// [&]() { [&]() {
// move_tile_window(g_win, move_tile_window(g_win,
// {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}}); {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
// }; };
statically_indexed_array<d_thread_type, 2> ds; statically_indexed_array<d_thread_type, 2> ds;
auto gld_d = [&]<typename PreNop = bool_constant<false>>( auto gld_d = [&]<typename PreNop = bool_constant<false>>(
...@@ -296,10 +296,10 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -296,10 +296,10 @@ struct FusedMoeGemmPipeline_Flatmm
{ {
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
}; };
// auto move_d = [&]() { auto move_d = [&]() {
// // d move along gemm-n // d move along gemm-n
// move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}}); move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
// }; };
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>( auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
auto& o_, auto i_access, PreNop = {}) auto& o_, auto i_access, PreNop = {})
...@@ -427,53 +427,66 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -427,53 +427,66 @@ struct FusedMoeGemmPipeline_Flatmm
// mfma(that can reuse the B matrix) only affected by M repeat. // mfma(that can reuse the B matrix) only affected by M repeat.
auto pipeline_gemm0 = [&]() { auto pipeline_gemm0 = [&]() {
constexpr index_t total_loops = issues_gemm0; constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0; constexpr index_t mfma_per_ld = total_loops / (issues_g + issues_a + issues_sld_a);
constexpr index_t mfma_per_gld_a = total_loops / issues_a;
constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
// compute buffer 0 // compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue); gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0)
{
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
move_g();
}
if constexpr(i_issue % mfma_per_gld_a == 0) if constexpr(i_issue % mfma_per_ld == 0)
{ {
gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{}); constexpr index_t ld_id = 0;
move_a();
if constexpr(ld_id < issues_g)
{
gld_g(gs[I0], number<ld_id>{});
}
if constexpr(ld_id - issues_g < +issues_a)
{
gld_a(a_sst_win0, number<ld_id - issues_g>{});
}
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
{
sld_a(as[I1], a_sld_win1, number<ld_id - issues_g - issues_a>{});
}
ld_id++;
} }
if constexpr(i_issue % mfma_per_sld_a == 0)
{
block_sync_lds();
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
}
}); });
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
// compute buffer 1 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue); gemm_0(acc_0, as[I1], gs[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0)
{
gld_g(gs[I0], number<i_issue / mfma_per_gld_g>{});
move_g();
}
if constexpr(i_issue % mfma_per_gld_a == 0)
{
gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{});
move_a();
}
if constexpr(i_issue % mfma_per_sld_a == 0) if constexpr(i_issue % mfma_per_ld == 0)
{ {
block_sync_lds(); constexpr index_t ld_id = 0;
sld_a(as[I0], a_sld_win0, number<i_issue / mfma_per_sld_a>{});
if constexpr(ld_id < issues_g)
{
gld_g(gs[I1], number<ld_id>{});
}
if constexpr(ld_id - issues_g < +issues_a)
{
gld_a(a_sst_win1, number<ld_id - issues_g>{});
}
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
{
sld_a(as[I0], a_sld_win0, number<ld_id - issues_g - issues_a>{});
}
ld_id++;
} }
}); });
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
}; };
auto pipeline_gemm0_tail = [&]() { auto pipeline_gemm0_tail = [&]() {
...@@ -486,14 +499,23 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -486,14 +499,23 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue); gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0) if constexpr(i_issue % mfma_per_gld_g == 0)
{
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{}); gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
move_g();
}
// if constexpr (i_issue % mfma_per_gld_a == 0) // if constexpr (i_issue % mfma_per_gld_a == 0)
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{}); // gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
if constexpr(i_issue % mfma_per_sld_a == 0) // if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{}); // {
// block_sync_load_raw(a_sst_win0.get_num_of_access());
// sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
// }
}); });
// if cycle_mfma>gld_a sync here
block_sync_load_raw(issues_g);
sld_a(as[I1], a_sld_win1, NEG1{});
// compute buffer 1 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
...@@ -523,7 +545,10 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -523,7 +545,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue); gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0) if constexpr(i_issue % mfma_per_atm_o == 0)
{ {
...@@ -536,7 +561,10 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -536,7 +561,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue); gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0) if constexpr(i_issue % mfma_per_atm_o == 0)
{ {
...@@ -553,7 +581,10 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -553,7 +581,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue); gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
move_d();
}
}); });
}; };
auto pipeline_gemm1_tail = [&]() { auto pipeline_gemm1_tail = [&]() {
...@@ -564,7 +595,10 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -564,7 +595,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue); gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0) if constexpr(i_issue % mfma_per_atm_o == 0)
{ {
...@@ -586,10 +620,13 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -586,10 +620,13 @@ struct FusedMoeGemmPipeline_Flatmm
move_g(); move_g();
clear_tile(acc_0); clear_tile(acc_0);
async_load_fence_raw(g_win.get_num_of_access()); // preload for next round
sld_a(as[I0], a_sld_win0, NEG1); gld_a(a_sst_win1, NEG1);
gld_a(a_sst_win1, NEG1); gld_g(gs[I1], NEG1);
// make sure a,g loaded
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
// we manually unroll double buffer inside hot loop // we manually unroll double buffer inside hot loop
const index_t iters_0 = (num_blocks_k0 - 2) / 2; const index_t iters_0 = (num_blocks_k0 - 2) / 2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment