Commit 08e44540 authored by coderfeli's avatar coderfeli
Browse files

debugs

parent a75f162b
...@@ -25,7 +25,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -25,7 +25,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{ {
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>; using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
// using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
...@@ -37,7 +38,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -37,7 +38,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{ {
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>; // using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
// clang-format on // clang-format on
......
...@@ -33,7 +33,7 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -33,7 +33,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>; using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>; using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>; using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
// S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32> // S<32, 1024|512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
static constexpr ck_tile::index_t BI_ = static constexpr ck_tile::index_t BI_ =
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
...@@ -44,7 +44,7 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -44,7 +44,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; // S<1, 4, 1> using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; // S<1, 4, 1>
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32> using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32>
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; // 32, 128, 512 using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; // 32, 128, 512|256
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>; /// S<1, 4, 1> using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>; /// S<1, 4, 1>
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32> using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>; // S<16, 16, 32>
......
...@@ -10,8 +10,13 @@ ...@@ -10,8 +10,13 @@
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -13,5 +13,8 @@ template float fused_moegemm_< ...@@ -13,5 +13,8 @@ template float fused_moegemm_<
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0> fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -59,6 +59,42 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, ...@@ -59,6 +59,42 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
} }
return t; return t;
} }
/// Host-side re-layout of a rank-3 weight tensor `t` with lengths (b, n, k).
///
/// The tensor is reinterpreted as a higher-rank view whose factors depend on
/// `mfma_dtype` / `mfma_type`, and that view is permuted with
/// ck_tile::reference_permute. Unsupported (dtype, type) combinations return
/// `t` unchanged.
///
/// NOTE(review): judging by the name, `t` presumably holds the gate and up
/// projection weights stacked along `n`, and the leading factor 2 in the
/// 16-bit / mfma_type == 1 view presumably separates those two halves —
/// confirm against the caller.
///
/// @param t          input tensor; must have exactly three dimensions.
/// @param mfma_dtype element-type tag: "bf16", "fp16", "int8" or "fp8".
/// @param mfma_type  layout variant selector (0 or 1).
/// @return the permuted tensor, or a copy of `t` when no branch matches.
template <typename T>
auto shuffle_moe_weight_gateup(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
    assert(t.get_lengths().size() == 3);
    int b_ = t.get_lengths()[0];
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[2];

    const bool is_16bit = (mfma_dtype == "bf16" || mfma_dtype == "fp16");
    const bool is_8bit  = (mfma_dtype == "int8" || mfma_dtype == "fp8");

    // Every branch sizes its view with integer division. If a dimension is not
    // a multiple of the tile factor, the view holds fewer elements than `t`
    // and the std::copy below would write past its end, so check up front.
    if(is_16bit && mfma_type == 0)
    {
        assert(n_ % 32 == 0 && k_ % 16 == 0);
        ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
        std::copy(t.begin(), t.end(), t_view.begin());
        return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
    }
    else if(is_16bit && mfma_type == 1)
    {
        // n is factored as 2 * (n / 512) * 16 * 16, k as (k / 32) * 4 * 8.
        assert(n_ % 512 == 0 && k_ % 32 == 0);
        ck_tile::HostTensor<T> t_view({b_, 2 , n_ / 512, 16 , 16, k_ / 32, 4, 8});
        std::copy(t.begin(), t.end(), t_view.begin());
        // Swaps axes 1 and 2, and moves axis 4 behind axes 5 and 6.
        return ck_tile::reference_permute(t_view, {0, 2, 1, 3, 5, 6, 4, 7});
    }
    else if(is_8bit && mfma_type == 0)
    {
        assert(n_ % 32 == 0 && k_ % 32 == 0);
        ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
        std::copy(t.begin(), t.end(), t_view.begin());
        return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
    }
    else if(is_8bit && mfma_type == 1)
    {
        assert(n_ % 16 == 0 && k_ % 64 == 0);
        ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
        std::copy(t.begin(), t.end(), t_view.begin());
        return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
    }
    return t;
}
template <typename IndexType> template <typename IndexType>
void topid_unique_gen( void topid_unique_gen(
......
...@@ -427,10 +427,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_ ...@@ -427,10 +427,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), // [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), // [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), // [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), // [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes), [s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes), [s_tile_os_b]"s"(tile_stride_b_bytes),
......
...@@ -63,7 +63,7 @@ struct FusedMoeGemmShape ...@@ -63,7 +63,7 @@ struct FusedMoeGemmShape
// S<32, 512, 128>, S<1, 4, 1>, S<16, 16, 32> // S<32, 512, 128>, S<1, 4, 1>, S<16, 16, 32>
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{}); //32 static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{}); //32
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); //512 static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); //256
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{}); // 128 static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{}); // 128
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{}); // 1 static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{}); // 1
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{}); // 4 static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{}); // 4
...@@ -73,18 +73,18 @@ struct FusedMoeGemmShape ...@@ -73,18 +73,18 @@ struct FusedMoeGemmShape
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{}); // 32 static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{}); // 32
static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0; static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0; static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0; // 64
static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0; static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
static_assert(Block_M0 % ThreadPerBlock_M0 == 0); static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
static_assert(Block_N0 % ThreadPerBlock_N0 == 0); static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
static_assert(Block_K0 % ThreadPerBlock_K0 == 0); static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0; // 2 static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0; // 2
static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0; // 8 static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0; // 4
static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0; // 4 static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0; // 4
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{}); //32 static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{}); //32
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{}); //128 static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{}); //128
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); //512 static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); //256
static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{}); // 1 static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{}); // 1
static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{}); // 4 static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{}); // 4
static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{}); // 1 static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{}); // 1
...@@ -100,7 +100,7 @@ struct FusedMoeGemmShape ...@@ -100,7 +100,7 @@ struct FusedMoeGemmShape
static_assert(Block_K1 % ThreadPerBlock_K1 == 0); static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1; // 2 static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1; // 2
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; // 2 static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; // 2
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; // 16 static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; // 8
static constexpr index_t BlockSize = warpSize * NumWarps; static constexpr index_t BlockSize = warpSize * NumWarps;
...@@ -118,7 +118,7 @@ struct FusedMoeGemmShape ...@@ -118,7 +118,7 @@ struct FusedMoeGemmShape
static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0; static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
static constexpr index_t Block_W1 = Warp_N1 * Warp_K1; // 512 static constexpr index_t Block_W1 = Warp_N1 * Warp_K1; // 512
static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1; // 8 static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1; // 8
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; // 16 static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; // 8
static_assert(Block_W0 == Block_W1); static_assert(Block_W0 == Block_W1);
// static_assert(Block_Nr0 == Block_Kr1); // static_assert(Block_Nr0 == Block_Kr1);
......
...@@ -243,7 +243,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -243,7 +243,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{ {
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
// constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2; constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{ {
...@@ -251,7 +251,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -251,7 +251,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// number<S_::Repeat_N0>{}.eee(); // number<S_::Repeat_N0>{}.eee();
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0, return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
S_::WarpPerBlock_K0, S_::WarpPerBlock_K0,
S_::Repeat_N0, /// hidden_radio_0, S_::Repeat_N0 * hidden_radio_0,
S_::Repeat_K0, S_::Repeat_K0,
get_warp_size(), get_warp_size(),
GetAlignment_G<Problem>()>(); GetAlignment_G<Problem>()>();
...@@ -803,7 +803,21 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -803,7 +803,21 @@ struct FusedMoeGemmPipelineFlatmmPolicy
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{ {
return Flatmm_32x512x128_1x4x1_16x16x32_FP16{}; return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x256x128_1x4x1_16x16x32_BF16{};
} }
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x256x128_1x4x1_16x16x32_FP16{};
}
} }
template <typename Problem> template <typename Problem>
...@@ -851,6 +865,26 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -851,6 +865,26 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{}; // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{}; return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
} }
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 256 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 256 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl{};
}
} }
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize(); constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * (IsGateOnly ? 1 : 2); BlockShape::Block_M0 * BlockShape::Block_N0;
return max(smem_0, max(smem_1, smem_bridge)); return max(smem_0, max(smem_1, smem_bridge));
} }
...@@ -168,7 +168,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -168,7 +168,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t intermediate_tile_id) index_t intermediate_tile_id)
{ {
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size / hidden_radio_0; ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0; ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
...@@ -178,13 +178,13 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -178,13 +178,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size * hidden_radio_0; index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size; index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
// nr*kr*w // nr*kr*w
index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane( index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
intermediate_tile_id * intermediate_tile_id *
BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W) BlockShape::Block_Nr0 * hidden_radio_0); // intermediate_tile_id * Block_N / (N in W)
index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane( index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane(
intermediate_tile_id * intermediate_tile_id *
...@@ -218,7 +218,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -218,7 +218,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto g_window_ = make_tile_window_linear_raw( auto g_window_ = make_tile_window_linear_raw(
g_view_, g_view_,
make_tuple(number<BlockShape::Block_Nr0>{}, make_tuple(number<BlockShape::Block_Nr0 * hidden_radio_0>{},
number<BlockShape::Block_Kr0>{}, number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}), number<BlockShape::Block_W0>{}),
{0, 0, 0}, {0, 0, 0},
...@@ -324,7 +324,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -324,7 +324,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
kargs.hidden_size, kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll BlockShape::Block_W0); // tile offset for B matrix each unroll
// fast GeLu // fast GeLu
if constexpr(std::is_same_v<typename Problem::GateActivation, if constexpr(std::is_same_v<typename Problem::GateActivation,
...@@ -351,32 +351,32 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -351,32 +351,32 @@ struct FusedMoeGemmPipeline_FlatmmUk
block_sync_lds(); block_sync_lds();
// up // up
if(!IsGateOnly) // if(!IsGateOnly)
{ // {
// up ptr. add hafl expoert_stride_0 as offset. // // up ptr. add hafl expoert_stride_0 as offset.
auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size); // auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; // auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto u_coords = // auto u_coords =
generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); }, // generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
number<decltype(u_win)::NumAccess_NonLinear>{}); // number<decltype(u_win)::NumAccess_NonLinear>{});
// reuse UK0 // // reuse UK0
auto uk_0_u = Policy::template GetUK_0<Problem>(); // auto uk_0_u = Policy::template GetUK_0<Problem>();
auto acc_0_u = uk_0_u(a_res, // auto acc_0_u = uk_0_u(a_res,
a_coords, // a_coords,
u_res, // u_res,
u_coords, // u_coords,
smem, // smem,
kargs.hidden_size, // kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll // BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * // BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll // BlockShape::Block_W0); // tile offset for B matrix each unroll
// elementwise mul gate*up. // // elementwise mul gate*up.
sweep_tile( // sweep_tile(
y_pre, // y_pre,
[&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); }, // [&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
sequence<1, 1>{}); // sequence<1, 1>{});
block_sync_lds(); // block_sync_lds();
} // }
store_tile(bridge_sst_win, cast_tile<YDataType>(y_pre)); store_tile(bridge_sst_win, cast_tile<YDataType>(y_pre));
block_sync_lds(); block_sync_lds();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment