Unverified commit 0e5e29c4, authored by valarLip and committed by GitHub
Browse files

porting fmoe_sorting from moe_sorting (#1884)

* porting fmoe_sorting from moe_sorting

* pass default example test

* remod
parent 16fa63ea
...@@ -8,14 +8,15 @@ ...@@ -8,14 +8,15 @@
struct fused_moe_args struct fused_moe_args
{ {
const void* a_ptr; // [m, k], input token const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token (no need to do zeroing) const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
void* o_ptr; // [m, k], output token (no need to do zeroing)
const void* topk_ids_ptr; // [tokens, topk] const void* topk_ids_ptr; // [tokens, topk]
const void* topk_weight_ptr; // [tokens, topk] const void* topk_weight_ptr; // [tokens, topk]
...@@ -48,6 +49,8 @@ struct fused_moe_traits ...@@ -48,6 +49,8 @@ struct fused_moe_traits
int activation; // 0:gelu, 1:silu int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1 int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
bool local_expert_masking; // if mask experts as local expert
}; };
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
...@@ -10,7 +10,8 @@ ...@@ -10,7 +10,8 @@
struct fused_moesorting_trait struct fused_moesorting_trait
{ {
std::string index_type; std::string index_type;
std::string weight_type; // currently always float std::string weight_type; // currently always float
bool local_expert_masking; // if mask experts as local expert
}; };
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
......
...@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf ...@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
return 1; return 1;
}(); }();
auto t0 = fused_moesorting_trait{"int32", "fp32"}; auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
auto a0 = fused_moesorting_args{ auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids; a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights; a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids; a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights; a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
......
...@@ -24,20 +24,63 @@ ...@@ -24,20 +24,63 @@
return ave_time; return ave_time;
#else #else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ #define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr bool sub_token_onshot = sub_token_onshot_; \ constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
using ms_problem = \ constexpr bool sub_token_onshot = sub_token_onshot_; \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \ constexpr bool local_expert_masking = local_expert_masking_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
auto kargs = kernel::MakeKargs(a); \ ms_weight_type, \
const dim3 grids = kernel::GridSize(a); \ sub_token_tile, \
const dim3 blocks = kernel::BlockSize(a); \ sub_token_onshot, \
const auto lds_bytes = kernel::GetSmemSize(a); \ local_expert_masking>; \
float ave_time = ck_tile::launch_kernel( \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
// Picks the sub-token tile width from the divisibility of `row_` (8, then 4, then 2, otherwise 1)
// and expands MOE_SORTING_DISPATCH_ with that width. `sub_token_onshot_` and
// `local_expert_masking_` are forwarded unchanged and must be compile-time constants, because
// MOE_SORTING_DISPATCH_ binds them to `constexpr` variables / template arguments.
// Every branch ends in the `return` emitted by MOE_SORTING_DISPATCH_, so control never falls
// through this macro.
// `row_` is parenthesized so that an expression argument (e.g. `r + 1`) is reduced as a whole
// before the modulo is applied.
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
    if((row_) % 8 == 0)                                                                 \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else if((row_) % 4 == 0)                                                            \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else if((row_) % 2 == 0)                                                            \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else                                                                                \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_);             \
    }
// Turns the runtime flag `is_sub_token_onshot` into the literal `true` / `false` that is passed
// as the second argument of MOE_SORTING_DISPATCH_SUB_TOKEN_. The flag is not a macro parameter:
// a variable named `is_sub_token_onshot` must be in scope wherever this macro is expanded.
// `rows_` and `emask_` are forwarded untouched.
#define MOE_SORTING_DISPATCH_SUBTO_(rows_, emask_)            \
    if(!is_sub_token_onshot)                                  \
    {                                                         \
        MOE_SORTING_DISPATCH_SUB_TOKEN_(rows_, false, emask_) \
    }                                                         \
    else                                                      \
    {                                                         \
        MOE_SORTING_DISPATCH_SUB_TOKEN_(rows_, true, emask_)  \
    }
// Outermost dispatch step: turns the runtime flag `is_local_expert_masking` into the literal
// `true` / `false` handed to MOE_SORTING_DISPATCH_SUBTO_ as its masking argument. The flag is
// not a macro parameter: a variable named `is_local_expert_masking` must be in scope wherever
// this macro is expanded. `rows_` is forwarded untouched.
#define MOE_SORTING_DISPATCH_EMASK_(rows_)        \
    if(!is_local_expert_masking)                  \
    {                                             \
        MOE_SORTING_DISPATCH_SUBTO_(rows_, false) \
    }                                             \
    else                                          \
    {                                             \
        MOE_SORTING_DISPATCH_SUBTO_(rows_, true)  \
    }
#endif #endif
#if !MOE_SORTING_USE_EX_KERNEL #if !MOE_SORTING_USE_EX_KERNEL
...@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til ...@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto sub_token_ = r_ - 2; auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8; r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_; bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_; (void)c_;
if(is_sub_token_onshot)
{ MOE_SORTING_DISPATCH_EMASK_(r_);
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, true);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, true);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, true);
}
else
{
MOE_SORTING_DISPATCH_(1, true);
}
}
else
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, false);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, false);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, false);
}
else
{
MOE_SORTING_DISPATCH_(1, false);
}
}
// MOE_SORTING_DISPATCH_ETILE(0, 0); // MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif #endif
} }
......
...@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t activation = arg_parser.get_int("act"); ck_tile::index_t activation = arg_parser.get_int("act");
if(stride < 0) if(stride < 0)
stride = hidden_size; stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w"); std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st"); std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw"); std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq"); std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw"); std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st; prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
int fused_quant = arg_parser.get_int("fquant"); int fused_quant = arg_parser.get_int("fquant");
int gate_only = arg_parser.get_int("gate_only"); int gate_only = arg_parser.get_int("gate_only");
int api = arg_parser.get_int("api"); int api = arg_parser.get_int("api");
int balance = arg_parser.get_int("balance"); int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp"); int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init"); int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed"); uint32_t seed = arg_parser.get_uint32("seed");
bool local_expert_masking = false; // TODO...
// w0 (Gate+Up or Gate only, N size) // w0 (Gate+Up or Gate only, N size)
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp; ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
...@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
ck_tile::HostTensor<IndexDataType> local_expert_mask_host({experts});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk; int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded}); ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
...@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host); ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host);
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host); ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
...@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m, block_m,
activation, activation,
gate_only, gate_only,
fused_quant}; fused_quant,
local_expert_masking};
fused_moe_args args{a_buf.GetDeviceBuffer(), fused_moe_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
...@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
: nullptr,
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(), topk_ids_buf.GetDeviceBuffer(),
topk_weight_buf.GetDeviceBuffer(), topk_weight_buf.GetDeviceBuffer(),
...@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host, topk_ids_host,
topk_weight_host, topk_weight_host,
local_expert_mask_host,
sorted_token_ids_host, sorted_token_ids_host,
sorted_weight_host, sorted_weight_host,
sorted_expert_ids_host, sorted_expert_ids_host,
num_sorted_tiles_host.mData[0], num_sorted_tiles_host.mData[0],
experts, experts,
block_m); block_m,
local_expert_masking);
if(activation == 0) if(activation == 0)
{ {
CPU_FUSED_MOE(ck_tile::element_wise::Gelu); CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
...@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host, topk_ids_host,
topk_weight_host, topk_weight_host,
local_expert_mask_host,
sorted_token_ids_host, sorted_token_ids_host,
sorted_weight_host, sorted_weight_host,
sorted_expert_ids_host, sorted_expert_ids_host,
num_sorted_tiles_host.mData[0], num_sorted_tiles_host.mData[0],
experts, experts,
block_m); block_m,
local_expert_masking);
// done, preparing GPU buffer // done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host); ck_tile::DeviceMem a_buf(a_host);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment