Commit 2880f7a5 authored by OscarXu's avatar OscarXu
Browse files

[CK_TILE] Support moe with up gemm

parent 37b35146
...@@ -22,12 +22,24 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -22,12 +22,24 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on // clang-format on
return r; return r;
} }
...@@ -21,8 +21,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) ...@@ -21,8 +21,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::BlockTile_1, typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0, typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>; typename Ts_::WarpTile_0>;
using f_problem = using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType, typename Ts_::GDataType,
typename Ts_::DDataType, typename Ts_::DDataType,
typename Ts_::AccDataType, typename Ts_::AccDataType,
...@@ -33,7 +32,9 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) ...@@ -33,7 +32,9 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::YSmoothScaleDataType, typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType, typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType, typename Ts_::IndexDataType,
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded // ck_tile::element_wise::FastGeluAsm, //
// TODO: hardcoded
ck_tile::element_wise::Silu,
f_shape, f_shape,
f_traits>; f_traits>;
......
...@@ -40,7 +40,7 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -40,7 +40,7 @@ struct fmoe_ // traits, ugly name, only used for internal
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>; using BlockTile_0 = ck_tile::sequence<BT_, BI_ / (GateOnly_ ? 1 : 2), BH_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
......
...@@ -10,5 +10,8 @@ ...@@ -10,5 +10,8 @@
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -10,5 +10,8 @@ ...@@ -10,5 +10,8 @@
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -228,7 +228,8 @@ struct FusedMoeGemmKernel ...@@ -228,7 +228,8 @@ struct FusedMoeGemmKernel
int max_num_tokens_padded = int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk; hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded); // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size); return Partitioner::GridSize(max_num_tokens_padded,
hargs.intermediate_size / (IsGateOnly ? 1 : 2));
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
......
...@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize(); constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * (IsGateOnly ? 1 : 2);
return max(smem_0, max(smem_1, smem_bridge)); return max(smem_0, max(smem_1, smem_bridge));
} }
...@@ -165,7 +165,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -165,7 +165,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t intermediate_tile_id) index_t intermediate_tile_id)
{ {
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size / hidden_radio_0;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0; ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
...@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size * hidden_radio_0;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size; index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
// nr*kr*w // nr*kr*w
...@@ -200,10 +200,10 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -200,10 +200,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr), make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType)); kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
auto g_win = [&]() { auto gu_win_gen = [&](auto ptr_offset) {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) + const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 + static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0; ptr_offset + interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_view_ = make_naive_tensor_view<address_space_enum::global>( auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr, g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}), make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
...@@ -220,7 +220,8 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -220,7 +220,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
Policy::template MakeGlobalTileDistribution_G<Problem>(), Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{}); sequence<0, 1, 1>{});
return g_window_; return g_window_;
}(); };
auto g_win = gu_win_gen(0);
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); }, auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
...@@ -309,8 +310,8 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -309,8 +310,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto w_scale = GetWeightScale( auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)); row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>(); auto uk_0_g = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res, auto acc_0 = uk_0_g(a_res,
a_coords, a_coords,
g_res, g_res,
g_coords, g_coords,
...@@ -320,6 +321,10 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -320,6 +321,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape::Block_Kr0 * BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll BlockShape::Block_W0); // tile offset for B matrix each unroll
// fast GeLu
if constexpr(std::is_same_v<typename Problem::GateActivation,
ck_tile::element_wise::FastGeluAsm>)
{
sweep_tile( sweep_tile(
acc_0, acc_0,
[&](auto idx0, auto idx1) { [&](auto idx0, auto idx1) {
...@@ -329,12 +334,46 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -329,12 +334,46 @@ struct FusedMoeGemmPipeline_FlatmmUk
acc_0(idx1) = v_.y; acc_0(idx1) = v_.y;
}, },
sequence<1, 2>{}); sequence<1, 2>{});
}
else
{
sweep_tile(
acc_0,
[&](auto idx0) { typename Problem::GateActivation{}(acc_0(idx0), acc_0(idx0)); },
sequence<1, 1>{});
}
auto y_pre = acc_0;
block_sync_lds();
auto y_pre = cast_tile<YDataType>(acc_0); // up
if(!IsGateOnly)
{
// up ptr. add hafl expoert_stride_0 as offset.
auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto u_coords =
generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
number<decltype(u_win)::NumAccess_NonLinear>{});
// reuse UK0
auto uk_0_u = Policy::template GetUK_0<Problem>();
auto acc_0_u = uk_0_u(a_res,
a_coords,
u_res,
u_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
// elementwise mul gate*up.
sweep_tile(
y_pre,
[&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
sequence<1, 1>{});
block_sync_lds(); block_sync_lds();
}
store_tile(bridge_sst_win, y_pre); store_tile(bridge_sst_win, cast_tile<YDataType>(y_pre));
block_sync_lds(); block_sync_lds();
auto uk_1 = Policy::template GetUK_1<Problem>(); auto uk_1 = Policy::template GetUK_1<Problem>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment