Unverified commit 77d1210b, authored by HandH1998 and committed by GitHub
Browse files

fix moe_align_block_size (#2615)

parent 70dc2fbe
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "sgl-kernel" name = "sgl-kernel"
version = "0.0.2.post9" version = "0.0.2.post10"
description = "Kernel Library for SGLang" description = "Kernel Library for SGLang"
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
......
...@@ -118,31 +118,19 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int ...@@ -118,31 +118,19 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
} }
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
torch::Tensor num_tokens_post_pad) { torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors // tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t mem_tokens_cnts = ((num_experts + 1) * num_experts) * sizeof(int32_t);
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
// allocate global memory
int32_t* tokens_cnts;
int32_t* cumsum;
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
cudaMalloc(&cumsum, mem_cumsum);
// set dynamic shared mem
auto kernel = moe_align_block_size_kernel<scalar_t>; auto kernel = moe_align_block_size_kernel<scalar_t>;
kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
num_experts, block_size, topk_ids.numel(), tokens_cnts, cumsum); num_experts, block_size, topk_ids.numel(),
token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>());
cudaFree(tokens_cnts);
cudaFree(cumsum);
}); });
} }
......
...@@ -8,6 +8,8 @@ def moe_align_block_size( ...@@ -8,6 +8,8 @@ def moe_align_block_size(
sorted_token_ids, sorted_token_ids,
experts_ids, experts_ids,
num_tokens_post_pad, num_tokens_post_pad,
token_cnts_buffer,
cumsum_buffer,
): ):
_moe_align_block_size( _moe_align_block_size(
topk_ids, topk_ids,
...@@ -16,4 +18,6 @@ def moe_align_block_size( ...@@ -16,4 +18,6 @@ def moe_align_block_size(
sorted_token_ids, sorted_token_ids,
experts_ids, experts_ids,
num_tokens_post_pad, num_tokens_post_pad,
token_cnts_buffer,
cumsum_buffer,
) )
...@@ -18,8 +18,22 @@ def test_moe_align_block_size(): ...@@ -18,8 +18,22 @@ def test_moe_align_block_size():
) )
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
)
cumsum_buffer = torch.empty(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
moe_align_block_size( moe_align_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
token_cnts_buffer,
cumsum_buffer,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment