"...csrc/git@developer.sourcefind.cn:change/sglang.git" did not exist on "fb4ce17de697643ca602248810307e929af847e9"
Commit 8f4dc357 authored by dummycoderfe's avatar dummycoderfe
Browse files

add an loop unroll for moe lds ops

parent 68952cba
...@@ -3,6 +3,17 @@ ...@@ -3,6 +3,17 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \
return ave_time;
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -14,14 +25,43 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi ...@@ -14,14 +25,43 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
} }
using index_t = ck_tile::index_t; using index_t = ck_tile::index_t;
using ms_weight_type = float; using ms_weight_type = float;
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>; index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
using kernel = ck_tile::MoeSortingKernel<ms_problem>; switch(smem_io_unroll_num)
auto kargs = kernel::MakeKargs(a); {
const dim3 grids = kernel::GridSize(a); case(1): {
const dim3 blocks = kernel::BlockSize(a); MOE_SORTING_DISPATCH(1);
float ave_time = }
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); case(2): {
return ave_time; MOE_SORTING_DISPATCH(2);
}
case(3): {
MOE_SORTING_DISPATCH(3);
}
case(5): {
MOE_SORTING_DISPATCH(5);
}
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: {
MOE_SORTING_DISPATCH(4);
}
}
} }
return -1; return -1;
} }
...@@ -63,7 +63,7 @@ struct MoeSortingKernel ...@@ -63,7 +63,7 @@ struct MoeSortingKernel
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{ {
// TODO: need pad to multiply of warp size // TODO: need pad to multiply of warp size
return dim3(ck_tile::max(h.num_experts, ck_tile::get_warp_size())); return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
} }
// in byte // in byte
...@@ -119,12 +119,11 @@ struct MoeSortingKernel ...@@ -119,12 +119,11 @@ struct MoeSortingKernel
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1) index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i) for(int i = 0; i < num_experts; ++i)
{ {
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0; tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
} }
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{ {
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])]; ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
...@@ -157,7 +156,6 @@ struct MoeSortingKernel ...@@ -157,7 +156,6 @@ struct MoeSortingKernel
} }
*total_tokens_post_pad = unit_size_mdiv.div(cumsum[num_experts]); *total_tokens_post_pad = unit_size_mdiv.div(cumsum[num_experts]);
} }
__syncthreads(); __syncthreads();
if(tid < num_experts) if(tid < num_experts)
{ {
...@@ -167,6 +165,7 @@ struct MoeSortingKernel ...@@ -167,6 +165,7 @@ struct MoeSortingKernel
} }
} }
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{ {
index_t expert_id = topk_id[i]; index_t expert_id = topk_id[i];
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace ck_tile { namespace ck_tile {
template <typename IndexType_, typename WeightType_> template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
struct MoeSortingProblem struct MoeSortingProblem
{ {
// TODO: this kernel only support warp per row // TODO: this kernel only support warp per row
...@@ -18,5 +18,6 @@ struct MoeSortingProblem ...@@ -18,5 +18,6 @@ struct MoeSortingProblem
static constexpr index_t WarpSize = get_warp_size(); static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1; static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment