Unverified Commit c0adab48 authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

[CK_TILE] moe sorting ex kernel to support expert > 128 (#1840)

* moe sorting ex

* fix bug for race condition

* fix bug and optimze large expert

* fix

* optimize with sub_token_oneshot

* support skip empty tokens for expert sorting

* update moe_sorting

* tidy code
parent 2312f4aa
...@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[]) ...@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
.insert("k", "4", "topk") .insert("k", "4", "topk")
.insert("unit", "32", "unit_size") .insert("unit", "32", "unit_size")
.insert("moe_buf_size", "0", "moe_buf_size") .insert("moe_buf_size", "0", "moe_buf_size")
.insert("local_eid",
"-1",
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
"please make sure eid is in ascending order!")
.insert("seed", "-1", "seed to be used, -1 means random every time") .insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name") .insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int kname = args.get_int("kname"); int kname = args.get_int("kname");
int warmup = args.get_int("warmup"); int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat"); int repeat = args.get_int("repeat");
int max_output_ids = int max_output_ids =
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return false; return false;
} }
bool local_expert_masking = args.get_str("local_eid") != "-1";
auto local_expert_masking_host = [&]() {
if(local_expert_masking)
{
auto local_eid = args.get_int_vec("local_eid");
// std::vector<int> v_ {num_experts, 0};
ck_tile::HostTensor<IndexType> v_{{num_experts}};
v_.SetZero();
for(auto eid : local_eid)
{
if(eid >= num_experts)
{
throw std::runtime_error(
"local_eid larger than number of expert, please check");
}
v_.mData[eid] = 1;
}
return v_;
}
else
// return std::vector<int>{};
return ck_tile::HostTensor<IndexType>{{1}};
}();
// tokens already considered batch size // tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
...@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_host.get_element_space_size_in_bytes()); sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem local_expert_masking_dev(
local_expert_masking_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data()); topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data()); weights_dev.ToDevice(weights_host.data());
...@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{ {
moe_buf_dev.ToDevice(moe_buf_host.data()); moe_buf_dev.ToDevice(moe_buf_host.data());
} }
if(local_expert_masking)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
moe_sorting_trait trait{index_prec, weight_prec}; moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
sorted_ids_dev.GetDeviceBuffer(), sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(),
...@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup, warmup,
repeat}; repeat};
auto ms = moe_sorting(trait, karg, sc); auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
index_prec.c_str(), index_prec.c_str(),
weight_prec.c_str(), weight_prec.c_str(),
tokens, tokens,
num_experts, num_experts,
topk, topk);
ms);
if(local_expert_masking)
{
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
}
if(ms < 0) if(ms < 0)
printf("not supported\n"); printf("not supported\n");
else
printf("ms:%f, ", ms);
fflush(stdout); fflush(stdout);
if(ms < 0) if(ms < 0)
{ {
...@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int32_t ref_total_tokens_post_pad = 0; int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host, ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host, weights_host,
local_expert_masking_host,
sorted_ids_ref, sorted_ids_ref,
sorted_weights_ref, sorted_weights_ref,
sorted_expert_ids_ref, sorted_expert_ids_ref,
ref_total_tokens_post_pad, ref_total_tokens_post_pad,
num_experts, num_experts,
unit_size); unit_size,
local_expert_masking);
rtn &= ck_tile::check_err( rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host, rtn &= ck_tile::check_err(sorted_weights_host,
...@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
} }
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
} }
printf("valid:%s\n", rtn ? "y" : "n"); printf("valid:%s", rtn ? "y" : "n");
fflush(stdout);
if(!rtn)
printf(", (%d)", seed);
printf("\n");
fflush(stdout); fflush(stdout);
return rtn; return rtn;
} }
......
...@@ -3,6 +3,12 @@ ...@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
...@@ -17,6 +23,67 @@ ...@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \ if(a.num_experts <= 8) \
{ \ { \
...@@ -38,11 +105,13 @@ ...@@ -38,11 +105,13 @@
{ \ { \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
} }
#endif
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
{ {
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127) if(a.num_experts > 127)
{ {
printf("lds size exceed, only support experts <127 \n"); printf("lds size exceed, only support experts <127 \n");
...@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi ...@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
} }
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_;
MOE_SORTING_DISPATCH_EMASK_(r_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
} }
return -1; return -1;
} }
...@@ -11,6 +11,7 @@ struct moe_sorting_trait ...@@ -11,6 +11,7 @@ struct moe_sorting_trait
{ {
std::string index_type; std::string index_type;
std::string weight_type; // currently always float std::string weight_type; // currently always float
bool local_expert_masking; // if mask experts as local expert
}; };
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
......
...@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11 ...@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1 $EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1 $EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13 $EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
...@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator: ...@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// //
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) // max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU // * this could be larger than actual, since actual tokens are on GPU
// //
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......
...@@ -3,6 +3,12 @@ ...@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp" #include "fused_moesorting.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
...@@ -17,6 +23,24 @@ ...@@ -17,6 +23,24 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
using ms_problem = \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \ if(a.num_experts <= 8) \
{ \ { \
...@@ -38,11 +62,13 @@ ...@@ -38,11 +62,13 @@
{ \ { \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
} }
#endif
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
{ {
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127) if(a.num_experts > 127)
{ {
printf("lds size exceed, only support experts <127 \n"); printf("lds size exceed, only support experts <127 \n");
...@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til ...@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
} }
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
(void)c_;
if(is_sub_token_onshot)
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, true);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, true);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, true);
}
else
{
MOE_SORTING_DISPATCH_(1, true);
}
}
else
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, false);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, false);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, false);
}
else
{
MOE_SORTING_DISPATCH_(1, false);
}
}
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
} }
return -1; return -1;
} }
...@@ -14,12 +14,15 @@ namespace ck_tile { ...@@ -14,12 +14,15 @@ namespace ck_tile {
template <typename WeightType, typename IndexType = index_t> template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights, const HostTensor<WeightType>& weights,
const HostTensor<IndexType>& local_expert_mask,
HostTensor<IndexType>& p_sorted_token_ids, HostTensor<IndexType>& p_sorted_token_ids,
HostTensor<WeightType>& sorted_weight, HostTensor<WeightType>& sorted_weight,
HostTensor<IndexType>& sorted_expert_ids, HostTensor<IndexType>& sorted_expert_ids,
index_t& unit_cnt, index_t& unit_cnt,
const index_t experts, const index_t experts,
const index_t unit_size) const index_t unit_size,
bool local_expert_masking,
bool skip_experts_with_zero_token = true)
{ {
const index_t num_token = topk_ids.mDesc.get_lengths()[0]; const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1]; const index_t topk = topk_ids.mDesc.get_lengths()[1];
...@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif #endif
std::vector<std::vector<WeightType>> expert_token_weights( std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0)); experts, std::vector<WeightType>(unit_size, 0));
// count number of unit-size slices in this expert
std::vector<IndexType> expert_slices(experts, 1); std::vector<IndexType> expert_slices(experts, 1);
// count the tokens used in this expert
std::vector<IndexType> expert_slice_idxs(experts, 0); std::vector<IndexType> expert_slice_idxs(experts, 0);
// TODO: above 2 buffer seems duplicated
for(index_t t = 0; t < num_token; t++) for(index_t t = 0; t < num_token; t++)
{ {
...@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
IndexType* out_tokens = p_sorted_token_ids.data(); IndexType* out_tokens = p_sorted_token_ids.data();
WeightType* out_weights = sorted_weight.data(); WeightType* out_weights = sorted_weight.data();
IndexType* out_expert_id = sorted_expert_ids.data(); IndexType* out_expert_id = sorted_expert_ids.data();
int curr_expert_id = 0;
for(index_t e = 0; e < experts; e++) for(index_t e = 0; e < experts; e++)
{ {
if(local_expert_masking)
{
if(local_expert_mask(e) == 0)
continue;
}
if(skip_experts_with_zero_token)
{
if(expert_slice_idxs[e] == 0)
{
curr_expert_id++;
continue;
}
}
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size; out_tokens += expert_slices[e] * unit_size;
memcpy(out_weights, memcpy(out_weights,
...@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
for(index_t s = 0; s < expert_slices[e]; s++) for(index_t s = 0; s < expert_slices[e]; s++)
{ {
out_expert_id[s] = e; out_expert_id[s] = curr_expert_id;
unit_cnt++; unit_cnt++;
} }
out_expert_id += expert_slices[e]; out_expert_id += expert_slices[e];
curr_expert_id++;
} }
unit_cnt *= unit_size; unit_cnt *= unit_size;
return; return;
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
...@@ -14,7 +15,6 @@ ...@@ -14,7 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/common/utils.hpp"
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// //
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) // max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU // * this could be larger than actual, since actual tokens are on GPU
// //
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
......
...@@ -15,6 +15,10 @@ namespace ck_tile { ...@@ -15,6 +15,10 @@ namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
// clang-format off // clang-format off
// [indexing implementation-1] // [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices // using M_a as constexpr block_size to partition all tokens into different slices
...@@ -28,7 +32,7 @@ namespace ck_tile { ...@@ -28,7 +32,7 @@ namespace ck_tile {
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// //
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) // max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU // * this could be larger than actual, since actual tokens are on GPU
// //
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...@@ -55,6 +59,34 @@ namespace ck_tile { ...@@ -55,6 +59,34 @@ namespace ck_tile {
// num_tokens_post_padded_ptr : [28] // num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7] // num_sorted_tiles_ptr : [7]
// //
// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
// if enabled, the expert with no tokens will be skipped, in stead of padding to at least 1 unit_size(M_a)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
// num_tokens_post_padded_ptr : [24]
//
// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case)
// and modify the output expert-ID, because we will only have enbaled expert on specific GPU.
// we call expert input to this kernel as "global expert id", output as "local expert id"
//
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note original it was exper-id= 0, 2, 3, 5, but we produce "local expert id")
// num_tokens_post_padded_ptr : [20]
//
// * different from vLLM // * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id // 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr // 2)need sorted_weight_ptr
...@@ -67,10 +99,80 @@ namespace ck_tile { ...@@ -67,10 +99,80 @@ namespace ck_tile {
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
// //
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
{
/* num_experts + 1
* +--------------------------------------+
* | |
* | |
* | | * -> sub-tokens
* | |
* | |
* +--------------------------------------+
* | | 2 -> cumsum buffer
* +--------------------------------------+
*
*/
int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
int smem_rows = [&](){
index_t target_occupancy_ = 2;
constexpr index_t total_ = 65536 / sizeof(int);
constexpr index_t sub_unroll = 8;
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
// at lease 2 lines, one for sub_token unroll, one for cumsum
// should be enough
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
throw std::runtime_error("too many num_experts, can't allocate smem");
target_occupancy_ = 1;
}
int r = total_ / target_occupancy_ / smem_cols;
// round to sub_unroll multipl
int r_for_sub_token = r - cumsum_bufs;
r_for_sub_token = min(r_for_sub_token, num_tokens_);
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
r_for_sub_token = max(r_for_sub_token, 1);
if(r_for_sub_token > 1)
{
int r_unroll_ = r_for_sub_token / sub_unroll;
// round to 1x/2x/4x/8x number of sub_unroll
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
int mask_ = (1 << (31 - clz_)) - 1;
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
mask_ = ~mask_;
//printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
}
// final check
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
throw std::runtime_error("can't run this kernel, request LDS over size");
}
return r_for_sub_token + cumsum_bufs;
}();
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
return ck_tile::make_tuple(smem_rows, smem_cols);
}
struct MoeSortingHostArgs struct MoeSortingHostArgs
{ {
const void* p_topk_ids; // [token, topk] const void* p_topk_ids; // [token, topk]
const void* p_weights; // [token, topk] const void* p_weights; // [token, topk]
const void* p_local_expert_mask;
void* p_sorted_token_ids; void* p_sorted_token_ids;
void* p_sorted_weights; void* p_sorted_weights;
void* p_sorted_expert_ids; void* p_sorted_expert_ids;
...@@ -101,6 +203,7 @@ struct MoeSortingKernel ...@@ -101,6 +203,7 @@ struct MoeSortingKernel
{ {
const void* p_topk_ids; const void* p_topk_ids;
const void* p_weights; const void* p_weights;
const void* p_local_expert_mask;
void* p_sorted_token_ids; void* p_sorted_token_ids;
void* p_sorted_weights; void* p_sorted_weights;
void* p_sorted_expert_ids; void* p_sorted_expert_ids;
...@@ -111,8 +214,11 @@ struct MoeSortingKernel ...@@ -111,8 +214,11 @@ struct MoeSortingKernel
index_t moe_buf_bytes; index_t moe_buf_bytes;
index_t tokens_per_thread; index_t tokens_per_thread;
index_t smem_rows;
mdiv unit_size_mdiv; mdiv unit_size_mdiv;
mdiv topk_mdiv; mdiv topk_mdiv;
mdiv expert_mdiv;
// mdiv sub_tokens_mdiv;
}; };
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
...@@ -123,15 +229,25 @@ struct MoeSortingKernel ...@@ -123,15 +229,25 @@ struct MoeSortingKernel
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{ {
#if MOE_SORTING_USE_EX_KERNEL
(void)h;
return dim3(256);
#else
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size())); return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
#endif
} }
// in byte // in byte
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{ {
#if MOE_SORTING_USE_EX_KERNEL
auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
return smem_rows * smem_cols * sizeof(int);
#else
const auto blocks = BlockSize(h); const auto blocks = BlockSize(h);
// usually num_experts is power of 2, we pad 1 dword here for the row-size // usually num_experts is power of 2, we pad 1 dword here for the row-size
return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t); return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
#endif
} }
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
...@@ -139,6 +255,7 @@ struct MoeSortingKernel ...@@ -139,6 +255,7 @@ struct MoeSortingKernel
Kargs k; Kargs k;
k.p_topk_ids = h.p_topk_ids; k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights; k.p_weights = h.p_weights;
k.p_local_expert_mask = h.p_local_expert_mask;
k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights; k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids; k.p_sorted_expert_ids = h.p_sorted_expert_ids;
...@@ -152,10 +269,18 @@ struct MoeSortingKernel ...@@ -152,10 +269,18 @@ struct MoeSortingKernel
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)}; k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)}; k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
k.smem_rows = [&](){
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
(void) c_;
return r_;
}();
k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
// k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
return k; return k;
} }
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
// NOTE: wave_size need at least be 16!! dpp 16 is one row
template <typename data_t, int wave_size> template <typename data_t, int wave_size>
__device__ inline void wave_cumsum(data_t& thread_data) const __device__ inline void wave_cumsum(data_t& thread_data) const
{ {
...@@ -196,6 +321,40 @@ struct MoeSortingKernel ...@@ -196,6 +321,40 @@ struct MoeSortingKernel
bank_mask, bank_mask,
bound_ctrl))); // row_shr:4 bound_ctrl))); // row_shr:4
} }
if constexpr(wave_size == 8) {
// wave-size=8 need one extra shift
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x118,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
#if 0
constexpr int bank_mask_0_7 = 0b1100;
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
__builtin_amdgcn_update_dpp(0, /* old value */
__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask_0_7,
bound_ctrl))// row_newbcast:7
);
#else
data_t xxx =__builtin_bit_cast(data_t,
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask,
bound_ctrl)); // row_newbcast:7
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
thread_data = thread_data - yyy;
#endif
}
if constexpr(wave_size > 8) if constexpr(wave_size > 8)
{ {
thread_data = thread_data =
...@@ -224,6 +383,36 @@ struct MoeSortingKernel ...@@ -224,6 +383,36 @@ struct MoeSortingKernel
} }
} }
// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = warpSize>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
// constexpr int reduce_stage = 6; // 1<<6=64
// clang-format off
constexpr int reduce_stage = [](){
if constexpr(wave_size_ == 2) return 1;
else if constexpr(wave_size_ == 4) return 2;
else if constexpr(wave_size_ == 8) return 3;
else if constexpr(wave_size_ == 16) return 4;
else if constexpr(wave_size_ == 32) return 5;
else if constexpr(wave_size_ == 64) return 6;
else return 0;
}();
// clang-format on
T v_local = local;
#pragma unroll reduce_stage
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
{
int src_lane = __lane_id() ^ (1 << i_stage);
int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
T v_remote = bit_cast<T>(v_remote_tmp);
v_local = reduce_f(v_local, v_remote);
}
return v_local;
}
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{ {
return row * total_col + col; return row * total_col + col;
...@@ -257,37 +446,37 @@ struct MoeSortingKernel ...@@ -257,37 +446,37 @@ struct MoeSortingKernel
index_t* shared_mem = reinterpret_cast<index_t*>(smem); index_t* shared_mem = reinterpret_cast<index_t*>(smem);
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1) index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i) for(int i = 0; i < num_experts; ++i)
{ {
tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0; tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
} }
#pragma unroll Problem_::InternalLoadUnroll #pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{ {
++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])]; ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
} }
__syncthreads(); __syncthreads();
#if 1 #if 1
if(tid < num_experts) if(tid < num_experts)
{ {
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
index_t local_c[8]; index_t local_c[8];
index_t prev_c = 0; index_t prev_c = 0;
// TODO: manually unroll. pragma unroll does not work well when we have dependency // TODO: manually unroll. pragma unroll does not work well when we have dependency
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i+= 8) for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
{ {
local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)]; local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)]; local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)]; local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)]; local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)]; local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)]; local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)]; local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)]; local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
local_c[0] += prev_c; local_c[0] += prev_c;
local_c[1] += local_c[0]; local_c[1] += local_c[0];
...@@ -299,50 +488,56 @@ struct MoeSortingKernel ...@@ -299,50 +488,56 @@ struct MoeSortingKernel
local_c[7] += local_c[6]; local_c[7] += local_c[6];
prev_c = local_c[7]; prev_c = local_c[7];
tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0]; tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1]; tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2]; tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3]; tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4]; tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5]; tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6]; tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7]; tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
} }
} }
#else #else
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
// heuristic
{ {
if(tid < num_experts) if(tid < num_experts)
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
for(int i = 0; i < num_experts; i+=8) { for(int i = 0; i < num_experts; i += 8)
{
index_t local_c[8]; index_t local_c[8];
#pragma unroll #pragma unroll
for(int j = 0; j < 8; j++) { for(int j = 0; j < 8; j++)
local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)]; {
local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
} }
#pragma unroll #pragma unroll
for(int j = 0; j < 8; j++) { for(int j = 0; j < 8; j++)
{
wave_cumsum<int, 64>(local_c[j]); wave_cumsum<int, 64>(local_c[j]);
} }
#pragma unroll #pragma unroll
for(int j = 0; j < 8; j++) { for(int j = 0; j < 8; j++)
tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j]; {
tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
} }
} }
} }
#endif #endif
__syncthreads(); __syncthreads();
if constexpr (Problem::ExpertTile == 0) { if constexpr(Problem::ExpertTile == 0)
{
if(tid == 0) if(tid == 0)
{ {
cumsum[0] = 0; cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i) for(int i = 1; i <= num_experts; ++i)
{ {
auto current_units = [&]() { auto current_units = [&]() {
index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] + index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1; unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_); index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor; return max(y_, 1) * unit_size_mdiv.divisor;
...@@ -351,20 +546,24 @@ struct MoeSortingKernel ...@@ -351,20 +546,24 @@ struct MoeSortingKernel
} }
*p_total_tokens_post_pad = cumsum[num_experts]; *p_total_tokens_post_pad = cumsum[num_experts];
} }
} else { }
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert) else
// for simplicity, not check experts here. {
int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
// expert) for simplicity, not check experts here.
int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
int local_cumsum = padded_tokens_per_expert; int local_cumsum = padded_tokens_per_expert;
wave_cumsum<int, 64>(local_cumsum); wave_cumsum<int, 64>(local_cumsum);
if(tid == (num_experts - 1)) { if(tid == (num_experts - 1))
{
cumsum[0] = 0; cumsum[0] = 0;
*p_total_tokens_post_pad = local_cumsum; *p_total_tokens_post_pad = local_cumsum;
} }
if(tid < num_experts) { if(tid < num_experts)
{
cumsum[tid + 1] = local_cumsum; cumsum[tid + 1] = local_cumsum;
} }
} }
...@@ -384,7 +583,7 @@ struct MoeSortingKernel ...@@ -384,7 +583,7 @@ struct MoeSortingKernel
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{ {
index_t expert_id = topk_id[i]; index_t expert_id = topk_id[i];
index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)]; index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
index_t rank_post_pad = local_cnt + cumsum[expert_id]; index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t curr_token_id, curr_topk_id; uint32_t curr_token_id, curr_topk_id;
...@@ -394,15 +593,16 @@ struct MoeSortingKernel ...@@ -394,15 +593,16 @@ struct MoeSortingKernel
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i); p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif #endif
p_sorted_weights[rank_post_pad] = weights[i]; p_sorted_weights[rank_post_pad] = weights[i];
tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1; tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
} }
if constexpr (Problem::ExpertTile == 0) { if constexpr(Problem::ExpertTile == 0)
{
const index_t prefill_token = topk_mdiv.div(numel); const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts) if(tid < num_experts)
{ {
index_t expert_offset = index_t expert_offset =
cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
index_t expert_end = cumsum[tid + 1]; index_t expert_end = cumsum[tid + 1];
while(expert_offset < expert_end) while(expert_offset < expert_end)
{ {
...@@ -417,16 +617,19 @@ struct MoeSortingKernel ...@@ -417,16 +617,19 @@ struct MoeSortingKernel
} }
} }
} }
else { else
{
const index_t prefill_token = topk_mdiv.div(numel); const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32 // TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile; static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
{ {
index_t eid = tid / experts_per_wave; index_t eid = tid / experts_per_wave;
index_t expert_offset = index_t expert_offset = cumsum[eid] +
cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave; tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
tid % experts_per_wave;
index_t expert_end = cumsum[eid + 1]; index_t expert_end = cumsum[eid + 1];
if(eid < num_experts) { if(eid < num_experts)
{
while(expert_offset < expert_end) while(expert_offset < expert_end)
{ {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
...@@ -436,9 +639,362 @@ struct MoeSortingKernel ...@@ -436,9 +639,362 @@ struct MoeSortingKernel
p_sorted_token_ids[expert_offset] = prefill_token; p_sorted_token_ids[expert_offset] = prefill_token;
#endif #endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0); p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset+=experts_per_wave; expert_offset += experts_per_wave;
}
}
}
}
}
// only support index_t, and single pixel access
struct simple_smem_indexer
{
index_t* smem;
index_t row_stride;
// this is 2D
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
: smem(smem_), row_stride(row_stride_)
{
}
CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
{
return smem[i_row * row_stride + i_col];
}
CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
{
return smem[i_row * row_stride + i_col];
}
// this is 1D or linear
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
};
CK_TILE_DEVICE void
moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
const IndexType* __restrict__ local_expert_mask,
index_t* p_sorted_token_ids,
WeightType* p_sorted_weights,
index_t* p_sorted_expert_ids,
index_t* p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
const mdiv expert_mdiv,
const index_t smem_rows,
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
const index_t topk = topk_mdiv.divisor;
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
const index_t smem_cols = num_experts + 1;
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
smem_cols};
// #pragma unroll 8
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
// NOTE: below for loop can't have barrier inside!!
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
if constexpr(Problem::SubTokenOneShot)
smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
else
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
} }
__syncthreads(); // make sure different i_token iteration not overlap by different wave
} }
// counting
if(tid == 0)
{
smem_cumsum(0) = 0;
// smem_cumdup(0) = 0;
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
{
index_t local_c[Problem::SubTokenTile];
index_t cnt = 0;
for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
{
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
if constexpr(Problem::SubTokenOneShot)
{
local_c[j] = local_c[j] != 0 ? 1 : 0;
}
}
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
}
}
if(lane_group_os == 0)
smem_cumsum(i_e + 1) = cnt;
}
}
if constexpr(Problem::LocalExpertMasking)
{
smem_cumdup(0) = 0;
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
// reuse this buffer
smem_cumdup(i_e + 1) = local_expert_mask[i_e];
}
}
__syncthreads();
{
if(wid == 0)
{
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += warpSize)
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
int blocks_pers_expert =
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int pre_cumsum_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(lid == 0 ? i_e_ : 0);
else
return 0; // not used
}();
int local_masking = [&]() {
if constexpr(Problem::LocalExpertMasking)
return smem_cumdup(i_e_ + lid + 1);
else
return 0; // not used
}();
int padded_tokens_per_expert = [&]() {
int x_ = [&]() {
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
// if local_cnt is zero, blocks_pers_expert will be zero
// this is what we want to achieve
return blocks_pers_expert * unit_size_mdiv.divisor;
}
else
{
return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
}
}();
if constexpr(Problem::LocalExpertMasking)
{
return local_masking ? x_ : 0;
}
else
return x_;
}();
local_cumsum_ = padded_tokens_per_expert;
local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
// cumsum padded in case local cumsum is zero, but
// pre_sumsum has value, which will result int
// zero local cumsum(but we want at least padded)
wave_cumsum<int, warpSize>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
wave_cumsum<int, warpSize>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
// NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
// for above write however __syncthreads will cause barrier with waves other
// than 0(which is not we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
if((lid + i_e_ - warpSize) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
}
__syncthreads();
}
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
int e_start = smem_cumsum(i_e);
int e_end = smem_cumsum(i_e + 1);
int expert_id = [&]() {
if constexpr(Problem::LocalExpertMasking)
{
// local expert id from cumsum
return smem_cumdup(i_e);
}
else
return i_e;
}();
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[i_e] == 0)
continue;
}
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
if constexpr(!Problem::SubTokenOneShot)
{
// clear every time
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
uint32_t curr_token_id, curr_expert_id;
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
smem_tokens(curr_token_id, curr_expert_id) = 0;
}
__syncthreads();
// load again
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id_, curr_topk_id_;
topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
int curr_token_id = static_cast<int>(curr_token_id_);
int curr_topk_id = static_cast<int>(curr_topk_id_);
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
}
}
__syncthreads();
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
{
if constexpr(Problem::LocalExpertMasking)
{
if(local_expert_mask[eid] == 0)
continue;
}
int position = smem_cumsum(eid);
for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
i_sub_token += lane_group_sz)
{
auto x = smem_tokens(i_sub_token, eid);
int local_cnt_cache = x != 0 ? 1 : 0;
int local_cnt = local_cnt_cache;
wave_cumsum<int, lane_group_sz>(local_cnt);
if(x != 0)
{
// now x is topk value
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[position + local_cnt - 1] =
MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
#endif
p_sorted_weights[position + local_cnt - 1] =
weights[(i_token + i_sub_token) * topk + x - 1];
}
int remote_cnt = __builtin_amdgcn_ds_bpermute(
(lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
position += remote_cnt;
}
smem_cumsum(eid) = position;
}
}
__syncthreads();
}
// add the skip number
for(int eid = tid; eid < num_experts; eid += block_size)
{
int e_start = smem_cumsum(eid);
int e_end = smem_cumdup(eid + 1);
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
while(e_start < e_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
#else
p_sorted_token_ids[e_start] = tokens;
#endif
p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
e_start++;
} }
} }
} }
...@@ -456,6 +1012,24 @@ struct MoeSortingKernel ...@@ -456,6 +1012,24 @@ struct MoeSortingKernel
} }
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
extern __shared__ char smem[]; extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
(void)numel;
return moe_align_block_size_kernel_ex(
static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<const IndexType*>(kargs.p_local_expert_mask),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
static_cast<WeightType*>(kargs.p_sorted_weights),
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
kargs.expert_mdiv,
kargs.smem_rows,
smem);
#else
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids), return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights), static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.p_sorted_token_ids), static_cast<IndexType*>(kargs.p_sorted_token_ids),
...@@ -468,6 +1042,7 @@ struct MoeSortingKernel ...@@ -468,6 +1042,7 @@ struct MoeSortingKernel
kargs.unit_size_mdiv, kargs.unit_size_mdiv,
kargs.topk_mdiv, kargs.topk_mdiv,
smem); smem);
#endif
} }
}; };
......
...@@ -25,4 +25,28 @@ struct MoeSortingProblem ...@@ -25,4 +25,28 @@ struct MoeSortingProblem
InternalLoadUnroll_; // TODO: need better design(like tile size) InternalLoadUnroll_; // TODO: need better design(like tile size)
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
}; };
template <typename IndexType_,
typename WeightType_,
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
bool SubTokenOneShot_, // if we only loop over once or not
bool LocalExpertMasking_, // used in EP case
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment