Commit bfe0120a authored by dummycoderfe's avatar dummycoderfe
Browse files

add extblocksnel to set zeros for moebufs

parent 8f4dc357
...@@ -25,6 +25,7 @@ auto create_args(int argc, char* argv[]) ...@@ -25,6 +25,7 @@ auto create_args(int argc, char* argv[])
.insert("e", "8", "number of experts") .insert("e", "8", "number of experts")
.insert("k", "4", "topk") .insert("k", "4", "topk")
.insert("unit", "32", "unit_size") .insert("unit", "32", "unit_size")
.insert("moe_buf_size", "-1", "moe_buf_size")
.insert("seed", "-1", "seed to be used, -1 means random every time") .insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name") .insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
...@@ -69,6 +70,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -69,6 +70,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int topk = args.get_int("k"); int topk = args.get_int("k");
int seed = args.get_int("seed"); int seed = args.get_int("seed");
int unit_size = args.get_int("unit"); int unit_size = args.get_int("unit");
int moe_buf_size = args.get_int("moe_buf_size");
int kname = args.get_int("kname"); int kname = args.get_int("kname");
int warmup = args.get_int("warmup"); int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat"); int repeat = args.get_int("repeat");
...@@ -94,8 +96,10 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -94,8 +96,10 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1}); ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_host({max_output_ids / unit_size}, {1}); ck_tile::HostTensor<IndexType> expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1}); ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
ck_tile::HostTensor<IndexType> moe_buf_host({moe_buf_size}, {1});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host); ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, seed); topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
...@@ -104,9 +108,11 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -104,9 +108,11 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem expert_ids_dev(expert_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem expert_ids_dev(expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data()); topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data()); weights_dev.ToDevice(weights_host.data());
moe_buf_dev.ToDevice(moe_buf_host.data());
moe_sorting_trait trait{index_prec, weight_prec, experts, topk, unit_size, tokens}; moe_sorting_trait trait{index_prec, weight_prec, experts, topk, unit_size, tokens};
...@@ -116,10 +122,12 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -116,10 +122,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_weights_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(), expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_dev.GetDeviceBuffer(),
tokens, tokens,
unit_size, unit_size,
experts, experts,
topk}; topk,
moe_buf_size};
ck_tile::stream_config sc{nullptr, ck_tile::stream_config sc{nullptr,
true, true,
...@@ -146,6 +154,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -146,6 +154,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_weights_dev.FromDevice(sorted_weights_host.data()); sorted_weights_dev.FromDevice(sorted_weights_host.data());
expert_ids_dev.FromDevice(expert_ids_host.data()); expert_ids_dev.FromDevice(expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data()); sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
moe_buf_dev.FromDevice(moe_buf_host.data());
bool rtn = true; bool rtn = true;
if(validate) if(validate)
...@@ -153,6 +162,9 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -153,6 +162,9 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1}); ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1}); ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1}); ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> moe_buf_ref({moe_buf_size}, {1});
moe_buf_ref.SetZero();
int32_t total_tokens_post_pad = 0; int32_t total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host, ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host, weights_host,
...@@ -171,6 +183,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -171,6 +183,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
1e-6); 1e-6);
rtn &= ck_tile::check_err( rtn &= ck_tile::check_err(
expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), 1e-6, 1e-6); expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
rtn &= total_tokens_post_pad == sorted_id_cnt_host.mData[0]; rtn &= total_tokens_post_pad == sorted_id_cnt_host.mData[0];
} }
......
...@@ -5,13 +5,14 @@ ...@@ -5,13 +5,14 @@
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \ auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \ const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \ const dim3 blocks = kernel::BlockSize(a); \
float ave_time = \ const auto lds_bytes = kernel::GetSmemSize(a); \
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \ float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
......
...@@ -20,10 +20,12 @@ struct MoeSortingHostArgs ...@@ -20,10 +20,12 @@ struct MoeSortingHostArgs
void* sorted_weights; void* sorted_weights;
void* expert_ids; void* expert_ids;
void* total_tokens_post_pad; void* total_tokens_post_pad;
void* moe_buf;
index_t tokens; index_t tokens;
index_t unit_size; index_t unit_size;
index_t num_experts; index_t num_experts;
index_t topk; index_t topk;
index_t moe_buf_set_bytes;
}; };
template <typename Problem_> template <typename Problem_>
...@@ -46,33 +48,32 @@ struct MoeSortingKernel ...@@ -46,33 +48,32 @@ struct MoeSortingKernel
void* sorted_weights; void* sorted_weights;
void* expert_ids; void* expert_ids;
void* total_tokens_post_pad; void* total_tokens_post_pad;
void* moe_buf;
index_t tokens; index_t tokens;
index_t num_experts; index_t num_experts;
index_t moe_buf_set_bytes;
index_t tokens_per_thread; index_t tokens_per_thread;
mdiv unit_size_mdiv; mdiv unit_size_mdiv;
mdiv topk_mdiv; mdiv topk_mdiv;
}; };
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{ {
// TODO: assume num-experts not too much // TODO: assume num-experts not too much
return dim3(1); return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_set_bytes, BlockSize(h).x * 16));
} }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{ {
// TODO: need pad to multiply of warp size
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size())); return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
} }
// in byte // in byte
CK_TILE_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{ {
// const auto blocks = BlockSize(h); const auto blocks = BlockSize(h);
// return ((blockDim.x + 1) * k.num_experts + (k.num_experts + 1)) * sizeof(index_t); return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
// TODO: can not use dynamic calculation. need use static to guide compiler
return 65536;
} }
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
...@@ -83,9 +84,11 @@ struct MoeSortingKernel ...@@ -83,9 +84,11 @@ struct MoeSortingKernel
k.sorted_token_ids = h.sorted_token_ids; k.sorted_token_ids = h.sorted_token_ids;
k.sorted_weights = h.sorted_weights; k.sorted_weights = h.sorted_weights;
k.expert_ids = h.expert_ids; k.expert_ids = h.expert_ids;
k.moe_buf = h.moe_buf;
k.total_tokens_post_pad = h.total_tokens_post_pad; k.total_tokens_post_pad = h.total_tokens_post_pad;
k.tokens = h.tokens; k.tokens = h.tokens;
k.num_experts = h.num_experts; k.num_experts = h.num_experts;
k.moe_buf_set_bytes = h.moe_buf_set_bytes;
const auto blocks = BlockSize(h); const auto blocks = BlockSize(h);
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
...@@ -99,6 +102,15 @@ struct MoeSortingKernel ...@@ -99,6 +102,15 @@ struct MoeSortingKernel
return row * total_col + col; return row * total_col + col;
} }
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
{
const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
if(offset < buf_bytes / 16)
{
buf[offset] = uint8x16_t(0);
}
}
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id, CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights, const WeightType* __restrict__ weights,
index_t* sorted_token_ids, index_t* sorted_token_ids,
...@@ -192,8 +204,13 @@ struct MoeSortingKernel ...@@ -192,8 +204,13 @@ struct MoeSortingKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
if(blockIdx.x > 0)
{
return moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.moe_buf),
kargs.moe_buf_set_bytes);
}
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
__shared__ char smem[GetSmemSize()]; extern __shared__ char smem[];
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids), return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights), static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.sorted_token_ids), static_cast<IndexType*>(kargs.sorted_token_ids),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment