Commit 45131629 authored by carlushuang's avatar carlushuang
Browse files

update pipeline

parent f09dc1f3
......@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>;
fused_moegemm_<t_>(s, a);
}
// clang-format on
......
......@@ -11,4 +11,7 @@ template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 256, 128, 128>, S<1, 4, 1>, S<32, 32, 16>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -51,6 +51,11 @@ struct FusedMoeGemmPipeline_Flatmm
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
......@@ -146,10 +151,14 @@ struct FusedMoeGemmPipeline_Flatmm
auto a_win = make_tile_window_linear(
a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_win = make_tile_window_linear(
g_window_, Policy::template MakeGlobalTileDistribution_G<Problem>());
auto d_win = make_tile_window_linear(
d_window_, Policy::template MakeGlobalTileDistribution_D<Problem>());
auto g_win =
make_tile_window_linear(g_window_,
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
auto d_win =
make_tile_window_linear(d_window_,
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
auto o_win = make_tile_window_linear(
o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
......@@ -239,8 +248,8 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr auto issues_a = number<a_win.get_num_of_access()>{};
constexpr auto issues_g = number<g_win.get_num_of_access()>{};
constexpr auto issues_d = number<d_win.get_num_of_access()>{};
constexpr auto issues_o = number<o_win.get_num_of_access()>{};
// constexpr auto issues_d = number<d_win.get_num_of_access()>{};
// constexpr auto issues_o = number<o_win.get_num_of_access()>{};
constexpr auto issues_gemm0 =
number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
warp_gemm_0.get_num_of_access()>{};
......@@ -431,12 +440,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr index_t total_loops = issues_gemm0;
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops);
constexpr index_t SLD_A =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
constexpr index_t GLD_A =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
constexpr index_t GLD_B =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
constexpr auto c_sld_a_0 = MAKE_SC();
constexpr auto c_gld_a_0 = MAKE_SC();
constexpr auto c_gld_b_0 = MAKE_SC();
......@@ -480,36 +484,33 @@ struct FusedMoeGemmPipeline_Flatmm
};
auto pipeline_gemm0_tail = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0;
// constexpr index_t mfma_per_gld_a = total_loops / issues_a;
// constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
constexpr index_t total_loops = issues_gemm0;
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_g == 0)
{
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
move_g();
}
// if constexpr (i_issue % mfma_per_gld_a == 0)
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
constexpr index_t slot = sr.at(i_issue);
// if constexpr(i_issue % mfma_per_sld_a == 0)
// {
// block_sync_load_raw(a_sst_win0.get_num_of_access());
// sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
// }
if constexpr(slot & GLD_B)
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
// if cycle_mfma>gld_a sync here
block_sync_load_raw(issues_g);
sld_a(as[I1], a_sld_win1, NEG1);
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue, TRUE); // last gemm has nop
constexpr auto last_nop = [&]() {
if constexpr(i_issue == (total_loops - 1))
return TRUE;
else
return FALSE;
}();
gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
});
};
......@@ -527,73 +528,79 @@ struct FusedMoeGemmPipeline_Flatmm
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
auto pipeline_gemm1 = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr index_t mfma_per_gld_d = total_loops / issues_d; // BlockShape::Repeat_M0;
constexpr index_t mfma_per_atm_o = total_loops / issues_o;
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
// compute buffer 1
constexpr auto c_gld_b_0 = MAKE_SC();
constexpr auto c_gst_o_0 = MAKE_SC();
constexpr auto c_gld_b_1 = MAKE_SC();
constexpr auto c_gst_o_1 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
move_d();
}
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
if constexpr(i_issue % mfma_per_atm_o == 0)
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<i_issue / mfma_per_atm_o>{});
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
}
});
move_d();
// move_o();
// compute buffer 0
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
move_d();
}
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
if constexpr(i_issue % mfma_per_atm_o == 0)
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I1]);
atomic_add_o(out, number<i_issue / mfma_per_atm_o>{});
atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
}
});
move_d();
};
auto pipeline_gemm1_head = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr index_t mfma_per_gld_d = total_loops / issues_d;
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
move_d();
}
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
move_d();
};
auto pipeline_gemm1_tail = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr index_t mfma_per_gld_d = total_loops / issues_d;
constexpr index_t mfma_per_atm_o = total_loops / issues_o;
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gst_o_0 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
if constexpr(i_issue % mfma_per_gld_d == 0)
{
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
move_d();
}
if constexpr(i_issue % mfma_per_atm_o == 0)
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<i_issue / mfma_per_atm_o>{});
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
}
});
{
......@@ -620,7 +627,7 @@ struct FusedMoeGemmPipeline_Flatmm
// we manually unroll double buffer inside hot loop
const index_t iters_0 = (num_blocks_k0 - 2) / 2;
index_t i_0 = 0;
index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
while(i_0++ < iters_0)
{
pipeline_gemm0();
......@@ -630,7 +637,7 @@ struct FusedMoeGemmPipeline_Flatmm
pipeline_bridge();
const index_t iters_1 = (num_blocks_n1 - 2) / 2;
index_t i_1 = 0;
index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
pipeline_gemm1_head();
while(i_1++ < iters_1)
{
......
......@@ -641,6 +641,75 @@ struct FusedMoeGemmPipelineFlatmmPolicy
return seq_all;
// clang-format on
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
return seq_all;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
{
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
// the purpose is to hide thoes instructions under mfma
// every value inside seq<...> is a mask, indicating a specific operation
using S_ = typename Problem::BlockShape;
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 3
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
return seq_all;
// clang-format on
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 3
return seq_all;
// clang-format on
}
}
template <typename Problem>
......
......@@ -43,5 +43,6 @@ enum class FusedMoeGemmPipelineSequencerEnum
GLD_B = 1 << 3,
SST_A = 1 << 4, // shared store a
SST_B = 1 << 5,
GST_O = 1 << 6, // global store out
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment