Unverified Commit 0e5e29c4 authored by valarLip's avatar valarLip Committed by GitHub
Browse files

porting fmoe_sorting from moe_sorting (#1884)

* porting fmoe_sorting from moe_sorting

* pass default example test

* remod
parent 16fa63ea
...@@ -15,6 +15,7 @@ struct fused_moe_args ...@@ -15,6 +15,7 @@ struct fused_moe_args
const void* g_scale_ptr; // [e, 1, n], gate(up) scale const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
void* o_ptr; // [m, k], output token (no need to do zeroing) void* o_ptr; // [m, k], output token (no need to do zeroing)
const void* topk_ids_ptr; // [tokens, topk] const void* topk_ids_ptr; // [tokens, topk]
...@@ -48,6 +49,8 @@ struct fused_moe_traits ...@@ -48,6 +49,8 @@ struct fused_moe_traits
int activation; // 0:gelu, 1:silu int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1 int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
bool local_expert_masking; // if mask experts as local expert
}; };
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
...@@ -11,6 +11,7 @@ struct fused_moesorting_trait ...@@ -11,6 +11,7 @@ struct fused_moesorting_trait
{ {
std::string index_type; std::string index_type;
std::string weight_type; // currently always float std::string weight_type; // currently always float
bool local_expert_masking; // if mask experts as local expert
}; };
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
......
...@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf ...@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
return 1; return 1;
}(); }();
auto t0 = fused_moesorting_trait{"int32", "fp32"}; auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
auto a0 = fused_moesorting_args{ auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids; a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights; a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids; a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights; a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
......
...@@ -24,11 +24,16 @@ ...@@ -24,11 +24,16 @@
return ave_time; return ave_time;
#else #else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \ constexpr bool sub_token_onshot = sub_token_onshot_; \
using ms_problem = \ constexpr bool local_expert_masking = local_expert_masking_; \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \ using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \ auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \ const dim3 grids = kernel::GridSize(a); \
...@@ -38,6 +43,44 @@ ...@@ -38,6 +43,44 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif #endif
#if !MOE_SORTING_USE_EX_KERNEL #if !MOE_SORTING_USE_EX_KERNEL
...@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til ...@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto sub_token_ = r_ - 2; auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8; r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_; bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_; (void)c_;
if(is_sub_token_onshot)
{ MOE_SORTING_DISPATCH_EMASK_(r_);
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, true);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, true);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, true);
}
else
{
MOE_SORTING_DISPATCH_(1, true);
}
}
else
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, false);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, false);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, false);
}
else
{
MOE_SORTING_DISPATCH_(1, false);
}
}
// MOE_SORTING_DISPATCH_ETILE(0, 0); // MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif #endif
} }
......
...@@ -162,6 +162,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -162,6 +162,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int tp = arg_parser.get_int("tp"); int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init"); int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed"); uint32_t seed = arg_parser.get_uint32("seed");
bool local_expert_masking = false; // TODO...
// w0 (Gate+Up or Gate only, N size) // w0 (Gate+Up or Gate only, N size)
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp; ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
...@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
ck_tile::HostTensor<IndexDataType> local_expert_mask_host({experts});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk; int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded}); ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
...@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host); ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host);
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host); ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
...@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m, block_m,
activation, activation,
gate_only, gate_only,
fused_quant}; fused_quant,
local_expert_masking};
fused_moe_args args{a_buf.GetDeviceBuffer(), fused_moe_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
...@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
: nullptr,
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(), topk_ids_buf.GetDeviceBuffer(),
topk_weight_buf.GetDeviceBuffer(), topk_weight_buf.GetDeviceBuffer(),
...@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host, topk_ids_host,
topk_weight_host, topk_weight_host,
local_expert_mask_host,
sorted_token_ids_host, sorted_token_ids_host,
sorted_weight_host, sorted_weight_host,
sorted_expert_ids_host, sorted_expert_ids_host,
num_sorted_tiles_host.mData[0], num_sorted_tiles_host.mData[0],
experts, experts,
block_m); block_m,
local_expert_masking);
if(activation == 0) if(activation == 0)
{ {
CPU_FUSED_MOE(ck_tile::element_wise::Gelu); CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
...@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host, topk_ids_host,
topk_weight_host, topk_weight_host,
local_expert_mask_host,
sorted_token_ids_host, sorted_token_ids_host,
sorted_weight_host, sorted_weight_host,
sorted_expert_ids_host, sorted_expert_ids_host,
num_sorted_tiles_host.mData[0], num_sorted_tiles_host.mData[0],
experts, experts,
block_m); block_m,
local_expert_masking);
// done, preparing GPU buffer // done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host); ck_tile::DeviceMem a_buf(a_host);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment