Commit 8f4dc357 authored by dummycoderfe's avatar dummycoderfe
Browse files

add an loop unroll for moe lds ops

parent 68952cba
......@@ -3,6 +3,17 @@
#include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \
return ave_time;
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
......@@ -12,16 +23,45 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
printf("lds size exceed, only support experts <127 \n");
return -1;
}
using index_t = ck_tile::index_t;
using ms_weight_type = float;
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>;
using kernel = ck_tile::MoeSortingKernel<ms_problem>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
const dim3 blocks = kernel::BlockSize(a);
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs));
return ave_time;
using index_t = ck_tile::index_t;
using ms_weight_type = float;
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
switch(smem_io_unroll_num)
{
case(1): {
MOE_SORTING_DISPATCH(1);
}
case(2): {
MOE_SORTING_DISPATCH(2);
}
case(3): {
MOE_SORTING_DISPATCH(3);
}
case(5): {
MOE_SORTING_DISPATCH(5);
}
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: {
MOE_SORTING_DISPATCH(4);
}
}
}
return -1;
}
......@@ -63,7 +63,7 @@ struct MoeSortingKernel
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
// TODO: need pad to multiply of warp size
return dim3(ck_tile::max(h.num_experts, ck_tile::get_warp_size()));
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
}
// in byte
......@@ -119,12 +119,11 @@ struct MoeSortingKernel
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i)
{
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
......@@ -157,7 +156,6 @@ struct MoeSortingKernel
}
*total_tokens_post_pad = unit_size_mdiv.div(cumsum[num_experts]);
}
__syncthreads();
if(tid < num_experts)
{
......@@ -167,6 +165,7 @@ struct MoeSortingKernel
}
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
index_t expert_id = topk_id[i];
......
......@@ -9,14 +9,15 @@
namespace ck_tile {
template <typename IndexType_, typename WeightType_>
template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
struct MoeSortingProblem
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment