Commit bfe0120a authored by dummycoderfe's avatar dummycoderfe
Browse files

Add extra blocks to the kernel to zero the MoE buffers (moe_buf)

parent 8f4dc357
......@@ -25,6 +25,7 @@ auto create_args(int argc, char* argv[])
.insert("e", "8", "number of experts")
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
.insert("moe_buf_size", "-1", "moe_buf_size")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
......@@ -69,6 +70,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int topk = args.get_int("k");
int seed = args.get_int("seed");
int unit_size = args.get_int("unit");
int moe_buf_size = args.get_int("moe_buf_size");
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
......@@ -94,8 +96,10 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
ck_tile::HostTensor<IndexType> moe_buf_host({moe_buf_size}, {1});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
......@@ -104,9 +108,11 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem expert_ids_dev(expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
moe_buf_dev.ToDevice(moe_buf_host.data());
moe_sorting_trait trait{index_prec, weight_prec, experts, topk, unit_size, tokens};
......@@ -116,10 +122,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_weights_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_dev.GetDeviceBuffer(),
tokens,
unit_size,
experts,
topk};
topk,
moe_buf_size};
ck_tile::stream_config sc{nullptr,
true,
......@@ -146,6 +154,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_weights_dev.FromDevice(sorted_weights_host.data());
expert_ids_dev.FromDevice(expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
moe_buf_dev.FromDevice(moe_buf_host.data());
bool rtn = true;
if(validate)
......@@ -153,6 +162,9 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> moe_buf_ref({moe_buf_size}, {1});
moe_buf_ref.SetZero();
int32_t total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
......@@ -171,6 +183,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
1e-6);
rtn &= ck_tile::check_err(
expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
rtn &= total_tokens_post_pad == sorted_id_cnt_host.mData[0];
}
......
......@@ -5,13 +5,14 @@
// Dispatch helper: for the given unroll factor, instantiates MoeSortingKernel, derives the
// kernel arguments, grid size, block size and shared-memory (LDS) byte count from `a`,
// launches the kernel with stream config `s`, and returns the time value reported by
// ck_tile::launch_kernel (named ave_time here; presumably an average over the timed runs).
// Expects `a`, `s`, `index_t` and `ms_weight_type` to be visible at the expansion site.
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
......
......@@ -20,10 +20,12 @@ struct MoeSortingHostArgs
void* sorted_weights;
void* expert_ids;
void* total_tokens_post_pad;
void* moe_buf;
index_t tokens;
index_t unit_size;
index_t num_experts;
index_t topk;
index_t moe_buf_set_bytes;
};
template <typename Problem_>
......@@ -46,33 +48,32 @@ struct MoeSortingKernel
void* sorted_weights;
void* expert_ids;
void* total_tokens_post_pad;
void* moe_buf;
index_t tokens;
index_t num_experts;
index_t moe_buf_set_bytes;
index_t tokens_per_thread;
mdiv unit_size_mdiv;
mdiv topk_mdiv;
};
// Launch grid size. Block 0 runs the sorting path; every additional block only clears
// moe_buf, with each of its threads zeroing one 16-byte chunk. That is why the number of
// extra blocks is moe_buf_set_bytes divided (rounding up) by BlockSize(h).x * 16.
CK_TILE_HOST static constexpr auto GridSize(const Hargs&)
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
// TODO: assume num-experts not too much
return dim3(1);
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_set_bytes, BlockSize(h).x * 16));
}
// Threads per block, derived from the expert count and the warp size.
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
    // TODO(review): the original note asked for padding to a multiple of the warp size;
    // integer_least_multiple presumably already provides that - confirm and drop the TODO.
    const auto warp_size     = ck_tile::get_warp_size();
    const auto padded_thread = ck_tile::integer_least_multiple(h.num_experts, warp_size);
    return dim3(padded_thread);
}
// in byte
// Shared-memory requirement computed on the host:
// ((blocks.x + 1) * num_experts + (num_experts + 1)) entries of index_t.
// The dispatch macro passes this value as the LDS byte count at launch, and the kernel
// entry point declares its buffer as `extern __shared__`.
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
// const auto blocks = BlockSize(h);
// return ((blockDim.x + 1) * k.num_experts + (k.num_experts + 1)) * sizeof(index_t);
// TODO: can not use dynamic calculation. need use static to guide compiler
return 65536;
const auto blocks = BlockSize(h);
return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
......@@ -83,9 +84,11 @@ struct MoeSortingKernel
k.sorted_token_ids = h.sorted_token_ids;
k.sorted_weights = h.sorted_weights;
k.expert_ids = h.expert_ids;
k.moe_buf = h.moe_buf;
k.total_tokens_post_pad = h.total_tokens_post_pad;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.moe_buf_set_bytes = h.moe_buf_set_bytes;
const auto blocks = BlockSize(h);
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
......@@ -99,6 +102,15 @@ struct MoeSortingKernel
return row * total_col + col;
}
// Zeroes `buf_bytes` bytes starting at `buf`.
//
// Called from blocks with blockIdx.x >= 1 (block 0 takes the sorting path), so the linear
// thread index is rebased by one block. Each thread clears one 16-byte vector; `buf` is
// therefore expected to be suitably aligned for uint8x16_t stores.
//
// When buf_bytes is not a multiple of 16, the thread immediately after the last full
// vector clears the remaining bytes one at a time. GridSize() rounds the number of extra
// blocks up, so that thread exists whenever there is a tail.
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
{
    constexpr index_t chunk_bytes = 16; // bytes written per thread on the vector path

    const index_t offset     = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
    const index_t num_chunks = buf_bytes / chunk_bytes;

    if(offset < num_chunks)
    {
        buf[offset] = uint8x16_t(0);
    }
    else if(offset == num_chunks)
    {
        // Tail: fewer than 16 bytes remain; a vector store would run past the end.
        const index_t tail_bytes = buf_bytes % chunk_bytes;
        unsigned char* tail      = reinterpret_cast<unsigned char*>(buf + num_chunks);
        for(index_t i = 0; i < tail_bytes; ++i)
        {
            tail[i] = 0;
        }
    }
}
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
index_t* sorted_token_ids,
......@@ -192,8 +204,13 @@ struct MoeSortingKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if(blockIdx.x > 0)
{
return moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.moe_buf),
kargs.moe_buf_set_bytes);
}
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
__shared__ char smem[GetSmemSize()];
extern __shared__ char smem[];
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.sorted_token_ids),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment