Commit 45131629 authored by carlushuang's avatar carlushuang
Browse files

update pipeline

parent f09dc1f3
...@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>; using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
fused_moegemm_<t_>(s, a); fused_moegemm_<t_>(s, a);
} }
// clang-format on // clang-format on
......
...@@ -11,4 +11,7 @@ template float fused_moegemm_< ...@@ -11,4 +11,7 @@ template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
// Explicit instantiation of fused_moegemm_ for the fmoe_ configuration whose first three
// type arguments are bf16, the next four are float, and whose first shape is S<32, 256, 128, 128>.
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -51,6 +51,11 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -51,6 +51,11 @@ struct FusedMoeGemmPipeline_Flatmm
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>(); static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>(); static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
// Sequencer enum values converted to index_t once at class scope, so the pipeline
// stages can test a sequencer entry with `slot & MASK` (e.g. `slot & GLD_B`).
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
static constexpr index_t kBlockPerCu = []() { static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1) if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu; return Problem::kBlockPerCu;
...@@ -146,10 +151,14 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -146,10 +151,14 @@ struct FusedMoeGemmPipeline_Flatmm
auto a_win = make_tile_window_linear( auto a_win = make_tile_window_linear(
a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>()); a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_win = make_tile_window_linear( auto g_win =
g_window_, Policy::template MakeGlobalTileDistribution_G<Problem>()); make_tile_window_linear(g_window_,
auto d_win = make_tile_window_linear( Policy::template MakeGlobalTileDistribution_G<Problem>(),
d_window_, Policy::template MakeGlobalTileDistribution_D<Problem>()); sequence<0, 1, 1>{});
auto d_win =
make_tile_window_linear(d_window_,
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
auto o_win = make_tile_window_linear( auto o_win = make_tile_window_linear(
o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>()); o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
...@@ -239,8 +248,8 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -239,8 +248,8 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr auto issues_a = number<a_win.get_num_of_access()>{}; constexpr auto issues_a = number<a_win.get_num_of_access()>{};
constexpr auto issues_g = number<g_win.get_num_of_access()>{}; constexpr auto issues_g = number<g_win.get_num_of_access()>{};
constexpr auto issues_d = number<d_win.get_num_of_access()>{}; // constexpr auto issues_d = number<d_win.get_num_of_access()>{};
constexpr auto issues_o = number<o_win.get_num_of_access()>{}; // constexpr auto issues_o = number<o_win.get_num_of_access()>{};
constexpr auto issues_gemm0 = constexpr auto issues_gemm0 =
number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 * number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
warp_gemm_0.get_num_of_access()>{}; warp_gemm_0.get_num_of_access()>{};
...@@ -431,12 +440,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -431,12 +440,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr index_t total_loops = issues_gemm0; constexpr index_t total_loops = issues_gemm0;
constexpr auto sr = Policy::template GetSequencer_0<Problem>(); constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops); static_assert(sr.size() == total_loops);
constexpr index_t SLD_A =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
constexpr index_t GLD_A =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
constexpr index_t GLD_B =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
constexpr auto c_sld_a_0 = MAKE_SC(); constexpr auto c_sld_a_0 = MAKE_SC();
constexpr auto c_gld_a_0 = MAKE_SC(); constexpr auto c_gld_a_0 = MAKE_SC();
constexpr auto c_gld_b_0 = MAKE_SC(); constexpr auto c_gld_b_0 = MAKE_SC();
...@@ -480,36 +484,33 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -480,36 +484,33 @@ struct FusedMoeGemmPipeline_Flatmm
}; };
auto pipeline_gemm0_tail = [&]() { auto pipeline_gemm0_tail = [&]() {
constexpr index_t total_loops = issues_gemm0; constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0; constexpr auto sr = Policy::template GetSequencer_0<Problem>();
// constexpr index_t mfma_per_gld_a = total_loops / issues_a; static_assert(sr.size() == total_loops);
// constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0 // compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue); gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0) constexpr index_t slot = sr.at(i_issue);
{
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
move_g();
}
// if constexpr (i_issue % mfma_per_gld_a == 0)
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
// if constexpr(i_issue % mfma_per_sld_a == 0) if constexpr(slot & GLD_B)
// { gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
// block_sync_load_raw(a_sst_win0.get_num_of_access());
// sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
// }
}); });
// if cycle_mfma>gld_a sync here
block_sync_load_raw(issues_g); block_sync_load_raw(issues_g);
sld_a(as[I1], a_sld_win1, NEG1); sld_a(as[I1], a_sld_win1, NEG1);
// compute buffer 1 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue, TRUE); // last gemm has nop constexpr auto last_nop = [&]() {
if constexpr(i_issue == (total_loops - 1))
return TRUE;
else
return FALSE;
}();
gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
}); });
}; };
...@@ -527,73 +528,79 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -527,73 +528,79 @@ struct FusedMoeGemmPipeline_Flatmm
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1) // note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
auto pipeline_gemm1 = [&]() { auto pipeline_gemm1 = [&]() {
constexpr index_t total_loops = issues_gemm1; constexpr index_t total_loops = issues_gemm1;
constexpr index_t mfma_per_gld_d = total_loops / issues_d; // BlockShape::Repeat_M0; constexpr auto sr = Policy::template GetSequencer_1<Problem>();
constexpr index_t mfma_per_atm_o = total_loops / issues_o; static_assert(sr.size() == total_loops);
// compute buffer 1 constexpr auto c_gld_b_0 = MAKE_SC();
constexpr auto c_gst_o_0 = MAKE_SC();
constexpr auto c_gld_b_1 = MAKE_SC();
constexpr auto c_gst_o_1 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue); gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) constexpr index_t slot = sr.at(i_issue);
{ if constexpr(slot & GLD_B)
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0) if constexpr(slot & GST_O)
{ {
auto out = cast_tile<ODataType>(acc_1s[I0]); auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<i_issue / mfma_per_atm_o>{}); atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
} }
}); });
move_d();
// move_o();
// compute buffer 0 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue); gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) constexpr index_t slot = sr.at(i_issue);
{ if constexpr(slot & GLD_B)
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0) if constexpr(slot & GST_O)
{ {
auto out = cast_tile<ODataType>(acc_1s[I1]); auto out = cast_tile<ODataType>(acc_1s[I1]);
atomic_add_o(out, number<i_issue / mfma_per_atm_o>{}); atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
} }
}); });
move_d();
}; };
auto pipeline_gemm1_head = [&]() { auto pipeline_gemm1_head = [&]() {
constexpr index_t total_loops = issues_gemm1; constexpr index_t total_loops = issues_gemm1;
constexpr index_t mfma_per_gld_d = total_loops / issues_d; constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0 // compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue); gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0) constexpr index_t slot = sr.at(i_issue);
{ if constexpr(slot & GLD_B)
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{}); gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
move_d();
}
}); });
move_d();
}; };
auto pipeline_gemm1_tail = [&]() { auto pipeline_gemm1_tail = [&]() {
constexpr index_t total_loops = issues_gemm1; constexpr index_t total_loops = issues_gemm1;
constexpr index_t mfma_per_gld_d = total_loops / issues_d; constexpr auto sr = Policy::template GetSequencer_1<Problem>();
constexpr index_t mfma_per_atm_o = total_loops / issues_o; static_assert(sr.size() == total_loops);
constexpr auto c_gst_o_0 = MAKE_SC();
// compute buffer 1 // compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) { static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue); gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0) constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GST_O)
{ {
auto out = cast_tile<ODataType>(acc_1s[I0]); auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<i_issue / mfma_per_atm_o>{}); atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
} }
}); });
{ {
...@@ -620,7 +627,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -620,7 +627,7 @@ struct FusedMoeGemmPipeline_Flatmm
// we manually unroll double buffer inside hot loop // we manually unroll double buffer inside hot loop
const index_t iters_0 = (num_blocks_k0 - 2) / 2; const index_t iters_0 = (num_blocks_k0 - 2) / 2;
index_t i_0 = 0; index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
while(i_0++ < iters_0) while(i_0++ < iters_0)
{ {
pipeline_gemm0(); pipeline_gemm0();
...@@ -630,7 +637,7 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -630,7 +637,7 @@ struct FusedMoeGemmPipeline_Flatmm
pipeline_bridge(); pipeline_bridge();
const index_t iters_1 = (num_blocks_n1 - 2) / 2; const index_t iters_1 = (num_blocks_n1 - 2) / 2;
index_t i_1 = 0; index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
pipeline_gemm1_head(); pipeline_gemm1_head();
while(i_1++ < iters_1) while(i_1++ < iters_1)
{ {
......
...@@ -641,6 +641,75 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -641,6 +641,75 @@ struct FusedMoeGemmPipelineFlatmmPolicy
return seq_all; return seq_all;
// clang-format on // clang-format on
} }
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 32 instructions: 16x buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async gld_a,
// 8x ds_read_b128 sld_a; total 32 slots :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
return seq_all;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
{
// This function returns a sequence<...> with one entry per issue slot of the mfma sequence.
// The purpose is to hide those gld/gst instructions under mfma.
// Every value inside sequence<...> is a mask: GLD_B, GST_O, or 0 for a slot with no
// extra operation. Callers read entry i and test it with `slot & GLD_B` / `slot & GST_O`.
// NOTE(review): there is no fallback branch; if neither `if constexpr` condition holds,
// no return statement remains, so callers using `.size()` / `.at()` would presumably fail
// to compile. Confirm this is the intended way to reject unsupported shapes.
using S_ = typename Problem::BlockShape;
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 slots: 32x GLD_B (buffer-load-dwordx4 gld_b), 8x GST_O, 24x empty (0).
// GST_O entries are all placed in the first two rows (slots 0-15).
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 3
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
return seq_all;
// clang-format on
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 32 slots: 16x GLD_B (buffer-load-dwordx4 gld_b), 8x GST_O, 8x empty (0).
// GST_O entries are all placed in the first two rows (slots 0-15).
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 3
return seq_all;
// clang-format on
}
} }
template <typename Problem> template <typename Problem>
......
...@@ -43,5 +43,6 @@ enum class FusedMoeGemmPipelineSequencerEnum ...@@ -43,5 +43,6 @@ enum class FusedMoeGemmPipelineSequencerEnum
GLD_B = 1 << 3, GLD_B = 1 << 3,
SST_A = 1 << 4, // shared store a SST_A = 1 << 4, // shared store a
SST_B = 1 << 5, SST_B = 1 << 5,
GST_O = 1 << 6, // global store out
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment