Commit 2880f7a5 authored by OscarXu's avatar OscarXu
Browse files

[CK_TILE] Support moe with up gemm

parent 37b35146
...@@ -22,12 +22,24 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -22,12 +22,24 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on // clang-format on
return r; return r;
} }
...@@ -16,26 +16,27 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) ...@@ -16,26 +16,27 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{ {
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>; using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0, using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
typename Ts_::WarpPerBlock_0, typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0, typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1, typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0, typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>; typename Ts_::WarpTile_0>;
using f_problem = using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType, typename Ts_::GDataType,
typename Ts_::GDataType, typename Ts_::DDataType,
typename Ts_::DDataType, typename Ts_::AccDataType,
typename Ts_::AccDataType, typename Ts_::ODataType,
typename Ts_::ODataType, typename Ts_::AScaleDataType,
typename Ts_::AScaleDataType, typename Ts_::GScaleDataType,
typename Ts_::GScaleDataType, typename Ts_::DScaleDataType,
typename Ts_::DScaleDataType, typename Ts_::YSmoothScaleDataType,
typename Ts_::YSmoothScaleDataType, typename Ts_::TopkWeightDataType,
typename Ts_::TopkWeightDataType, typename Ts_::IndexDataType,
typename Ts_::IndexDataType, // ck_tile::element_wise::FastGeluAsm, //
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded // TODO: hardcoded
f_shape, ck_tile::element_wise::Silu,
f_traits>; f_shape,
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>; // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>; using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
......
...@@ -40,7 +40,7 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -40,7 +40,7 @@ struct fmoe_ // traits, ugly name, only used for internal
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>; using BlockTile_0 = ck_tile::sequence<BT_, BI_ / (GateOnly_ ? 1 : 2), BH_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
......
...@@ -10,5 +10,8 @@ ...@@ -10,5 +10,8 @@
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -10,5 +10,8 @@ ...@@ -10,5 +10,8 @@
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -228,7 +228,8 @@ struct FusedMoeGemmKernel ...@@ -228,7 +228,8 @@ struct FusedMoeGemmKernel
int max_num_tokens_padded = int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk; hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded); // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size); return Partitioner::GridSize(max_num_tokens_padded,
hargs.intermediate_size / (IsGateOnly ? 1 : 2));
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
...@@ -382,7 +383,7 @@ struct FusedMoeGemmKernel ...@@ -382,7 +383,7 @@ struct FusedMoeGemmKernel
auto o_window = [&]() { auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr); ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global, auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>( memory_operation_enum::atomic_add>(
o_ptr, o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size), make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1), make_tuple(kargs.stride_token, 1),
......
...@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize(); constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * (IsGateOnly ? 1 : 2);
return max(smem_0, max(smem_1, smem_bridge)); return max(smem_0, max(smem_1, smem_bridge));
} }
...@@ -165,7 +165,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -165,7 +165,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t intermediate_tile_id) index_t intermediate_tile_id)
{ {
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size / hidden_radio_0;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0; ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
...@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size * hidden_radio_0;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size; index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
// nr*kr*w // nr*kr*w
...@@ -200,10 +200,10 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -200,10 +200,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr), make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType)); kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
auto g_win = [&]() { auto gu_win_gen = [&](auto ptr_offset) {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) + const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 + static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0; ptr_offset + interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_view_ = make_naive_tensor_view<address_space_enum::global>( auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr, g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}), make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
...@@ -220,7 +220,8 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -220,7 +220,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
Policy::template MakeGlobalTileDistribution_G<Problem>(), Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{}); sequence<0, 1, 1>{});
return g_window_; return g_window_;
}(); };
auto g_win = gu_win_gen(0);
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); }, auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
...@@ -309,32 +310,70 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -309,32 +310,70 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto w_scale = GetWeightScale( auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)); row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>(); auto uk_0_g = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res, auto acc_0 = uk_0_g(a_res,
a_coords, a_coords,
g_res, g_res,
g_coords, g_coords,
smem, smem,
kargs.hidden_size, kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll BlockShape::Block_W0); // tile offset for B matrix each unroll
sweep_tile( // fast GeLu
acc_0, if constexpr(std::is_same_v<typename Problem::GateActivation,
[&](auto idx0, auto idx1) { ck_tile::element_wise::FastGeluAsm>)
fp32x2_t v_{acc_0(idx0), acc_0(idx1)}; {
typename Problem::GateActivation{}(v_, v_); sweep_tile(
acc_0(idx0) = v_.x; acc_0,
acc_0(idx1) = v_.y; [&](auto idx0, auto idx1) {
}, fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
sequence<1, 2>{}); typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
auto y_pre = cast_tile<YDataType>(acc_0); acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
}
else
{
sweep_tile(
acc_0,
[&](auto idx0) { typename Problem::GateActivation{}(acc_0(idx0), acc_0(idx0)); },
sequence<1, 1>{});
}
auto y_pre = acc_0;
block_sync_lds(); block_sync_lds();
store_tile(bridge_sst_win, y_pre); // up
if(!IsGateOnly)
{
// up ptr. add hafl expoert_stride_0 as offset.
auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto u_coords =
generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
number<decltype(u_win)::NumAccess_NonLinear>{});
// reuse UK0
auto uk_0_u = Policy::template GetUK_0<Problem>();
auto acc_0_u = uk_0_u(a_res,
a_coords,
u_res,
u_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
// elementwise mul gate*up.
sweep_tile(
y_pre,
[&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
sequence<1, 1>{});
block_sync_lds();
}
store_tile(bridge_sst_win, cast_tile<YDataType>(y_pre));
block_sync_lds(); block_sync_lds();
auto uk_1 = Policy::template GetUK_1<Problem>(); auto uk_1 = Policy::template GetUK_1<Problem>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment