Commit a4140bb9 authored by carlushuang's avatar carlushuang
Browse files

use magiv div to accelerate compute

parent a61ccfe8
...@@ -110,16 +110,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -110,16 +110,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_sorting_trait trait{index_prec, weight_prec, experts, topk, unit_size, tokens}; moe_sorting_trait trait{index_prec, weight_prec, experts, topk, unit_size, tokens};
moe_sorting_kargs karg{topk_ids_dev.GetDeviceBuffer(), moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
sorted_ids_dev.GetDeviceBuffer(), sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(), expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
tokens, tokens,
unit_size, unit_size,
experts, experts,
topk}; topk};
ck_tile::stream_config sc{nullptr, ck_tile::stream_config sc{nullptr,
true, true,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
{ {
...@@ -15,14 +15,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf ...@@ -15,14 +15,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf
using index_t = ck_tile::index_t; using index_t = ck_tile::index_t;
using ms_weight_type = float; using ms_weight_type = float;
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>; using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>;
// using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>; using kernel = ck_tile::MoeSortingKernel<ms_problem>;
using kernel = ck_tile::MoeSortingKernel<ms_problem>; auto kargs = kernel::MakeKargs(a);
auto kargs = kernel::MakeKargs(a); const dim3 grids = kernel::GridSize(a);
const dim3 grids = 1; const dim3 blocks = kernel::BlockSize(a);
const dim3 blocks = ck_tile::max(t.experts, ck_tile::get_warp_size()); float ave_time =
const size_t lds_size = ((blocks.x + 1) * t.experts + (t.experts + 1)) * sizeof(index_t); ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs));
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs));
return ave_time; return ave_time;
} }
return -1; return -1;
......
...@@ -17,8 +17,8 @@ struct moe_sorting_trait ...@@ -17,8 +17,8 @@ struct moe_sorting_trait
int tokens; int tokens;
}; };
struct moe_sorting_kargs : public ck_tile::MoeSortingHostArgs struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
{ {
}; };
float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_config s); float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
...@@ -7,4 +7,5 @@ ...@@ -7,4 +7,5 @@
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_problem.hpp" #include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -29,7 +29,6 @@ struct MoeSortingHostArgs ...@@ -29,7 +29,6 @@ struct MoeSortingHostArgs
template <typename Problem_> template <typename Problem_>
struct MoeSortingKernel struct MoeSortingKernel
{ {
// using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using IndexType = typename Problem::IndexType; using IndexType = typename Problem::IndexType;
...@@ -37,10 +36,63 @@ struct MoeSortingKernel ...@@ -37,10 +36,63 @@ struct MoeSortingKernel
typedef MoeSortingHostArgs MoeSortingKargs; typedef MoeSortingHostArgs MoeSortingKargs;
using Kargs = MoeSortingKargs;
using Hargs = MoeSortingHostArgs; using Hargs = MoeSortingHostArgs;
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) { return h; } struct Kargs
{
const void* p_topk_ids;
const void* p_weights;
void* sorted_token_ids;
void* sorted_weights;
void* expert_ids;
void* total_tokens_post_pad;
index_t tokens;
index_t num_experts;
index_t tokens_per_thread;
mdiv unit_size_mdiv;
mdiv topk_mdiv;
};
CK_TILE_HOST static constexpr auto GridSize(const Hargs&)
{
// TODO: assume num-experts not too much
return dim3(1);
}
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
// TODO: need pad to multiply of warp size
return dim3(ck_tile::max(h.num_experts, ck_tile::get_warp_size()));
}
// in byte
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
{
// const auto blocks = BlockSize(h);
// return ((blockDim.x + 1) * k.num_experts + (k.num_experts + 1)) * sizeof(index_t);
// TODO: can not use dynamic calculation. need use static to guide compiler
return 65536;
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights;
k.sorted_token_ids = h.sorted_token_ids;
k.sorted_weights = h.sorted_weights;
k.expert_ids = h.expert_ids;
k.total_tokens_post_pad = h.total_tokens_post_pad;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
const auto blocks = BlockSize(h);
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
}
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{ {
...@@ -54,15 +106,16 @@ struct MoeSortingKernel ...@@ -54,15 +106,16 @@ struct MoeSortingKernel
index_t* expert_ids, index_t* expert_ids,
index_t* total_tokens_post_pad, index_t* total_tokens_post_pad,
const index_t num_experts, const index_t num_experts,
const index_t unit_size, const index_t tokens_per_thread,
const index_t numel, const index_t numel,
const index_t topk) const const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
void* smem) const
{ {
const index_t tokens_per_thread = integer_divide_ceil(numel, blockDim.x); const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t tid = static_cast<index_t>(threadIdx.x); const index_t start_idx = tid * tokens_per_thread;
const index_t start_idx = tid * tokens_per_thread;
extern __shared__ index_t shared_mem[]; index_t* shared_mem = reinterpret_cast<index_t*>(smem);
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1) index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
...@@ -88,28 +141,29 @@ struct MoeSortingKernel ...@@ -88,28 +141,29 @@ struct MoeSortingKernel
} }
} }
__syncthreads(); // __syncthreads();
if(tid == 0) if(tid == 0)
{ {
cumsum[0] = 0; cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i) for(int i = 1; i <= num_experts; ++i)
{ {
cumsum[i] = auto current_units = [&]() {
cumsum[i - 1] + index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size_mdiv.divisor - 1;
unit_size), index_t y_ = unit_size_mdiv.div(x_);
1) * return max(y_, 1) * unit_size_mdiv.divisor;
unit_size; }();
cumsum[i] = cumsum[i - 1] + current_units;
} }
*total_tokens_post_pad = cumsum[num_experts] / unit_size; *total_tokens_post_pad = unit_size_mdiv.div(cumsum[num_experts]);
} }
__syncthreads(); __syncthreads();
if(tid < num_experts) if(tid < num_experts)
{ {
for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size) for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
{ {
expert_ids[i / unit_size] = tid; expert_ids[unit_size_mdiv.div(i)] = tid;
} }
} }
...@@ -118,11 +172,12 @@ struct MoeSortingKernel ...@@ -118,11 +172,12 @@ struct MoeSortingKernel
index_t expert_id = topk_id[i]; index_t expert_id = topk_id[i];
index_t rank_post_pad = index_t rank_post_pad =
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id]; tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i / topk; sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
sorted_weights[rank_post_pad] = weights[i]; sorted_weights[rank_post_pad] = weights[i];
++tokens_cnts[calc_index(num_experts, tid, expert_id)]; ++tokens_cnts[calc_index(num_experts, tid, expert_id)];
} }
const index_t prefill_token = numel / topk;
const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts) if(tid < num_experts)
{ {
index_t expert_offset = index_t expert_offset =
...@@ -138,7 +193,8 @@ struct MoeSortingKernel ...@@ -138,7 +193,8 @@ struct MoeSortingKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
const size_t numel = kargs.tokens * kargs.topk; const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
__shared__ char smem[GetSmemSize()];
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids), return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights), static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.sorted_token_ids), static_cast<IndexType*>(kargs.sorted_token_ids),
...@@ -146,9 +202,11 @@ struct MoeSortingKernel ...@@ -146,9 +202,11 @@ struct MoeSortingKernel
static_cast<IndexType*>(kargs.expert_ids), static_cast<IndexType*>(kargs.expert_ids),
static_cast<IndexType*>(kargs.total_tokens_post_pad), static_cast<IndexType*>(kargs.total_tokens_post_pad),
kargs.num_experts, kargs.num_experts,
kargs.unit_size, kargs.tokens_per_thread,
numel, numel,
kargs.topk); kargs.unit_size_mdiv,
kargs.topk_mdiv,
smem);
} }
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment