Unverified commit b93ef5e5 authored by lukec, committed by GitHub

Remove the vllm dependency from the moe_align function (#4164)


Co-authored-by: Hongbosherlock <hongbosherlock@gmail.com>
parent d4017a6b
@@ -47,18 +47,18 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
+    int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
     int32_t* __restrict__ cumsum) {
-  __shared__ int32_t shared_counts[WARP_SIZE][8];
+  extern __shared__ int32_t shared_counts[];
   const int warp_id = threadIdx.x / WARP_SIZE;
-  const int experts_per_warp = 8;
   const int my_expert_start = warp_id * experts_per_warp;
   for (int i = 0; i < experts_per_warp; ++i) {
     if (my_expert_start + i < num_experts) {
-      shared_counts[warp_id][i] = 0;
+      shared_counts[warp_id * experts_per_warp + i] = 0;
     }
   }
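
For illustration only (not part of this commit), a minimal sketch of the indexing change above, using a hypothetical `zero_counts_kernel`: a fixed `__shared__ counts[WARP_SIZE][8]` array stores element `[w][i]` at linear offset `w * 8 + i`, so switching to a runtime `experts_per_warp` and an `extern __shared__` 1D buffer keeps the same addressing while letting the host choose the buffer size at launch.

```cuda
// Illustrative sketch (hypothetical kernel and names, not the commit's code).
#include <cstdint>

#define WARP_SIZE 32

__global__ void zero_counts_kernel(int32_t num_experts, int32_t experts_per_warp) {
  // Sized at launch: (warps per block) * experts_per_warp int32_t entries.
  extern __shared__ int32_t shared_counts[];
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int my_expert_start = warp_id * experts_per_warp;
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      // Same element the old fixed layout addressed as shared_counts[warp_id][i].
      shared_counts[warp_id * experts_per_warp + i] = 0;
    }
  }
}
```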
@@ -71,7 +71,7 @@ __global__ void moe_align_block_size_kernel(
     int expert_id = topk_ids[i];
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
-    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
+    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
   }
   __syncthreads();
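
A small standalone sketch of the counting step, with assumed names: each thread strides over the expert assignments and atomically increments the shared-memory counter belonging to the expert it reads. The strided loop header is not visible in the hunk above, so it is reconstructed here as an assumption.

```cuda
// Per-expert histogram sketch (hypothetical kernel; assumed loop structure).
#include <cstdint>

#define WARP_SIZE 32

__global__ void histogram_experts_kernel(const int32_t* __restrict__ topk_ids,
                                         size_t numel,
                                         int32_t experts_per_warp) {
  extern __shared__ int32_t shared_counts[];  // one row of counters per warp
  for (size_t i = threadIdx.x; i < numel; i += blockDim.x) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;       // which row of counters
    int expert_offset = expert_id % experts_per_warp;  // offset within that row
    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
  }
  __syncthreads();  // counts are complete before any thread reads them
}
```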
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
       int expert_count = 0;
       int warp_idx = (i - 1) / experts_per_warp;
       int expert_offset = (i - 1) % experts_per_warp;
-      expert_count = shared_counts[warp_idx][expert_offset];
+      expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
       cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
     }
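
For context, a host-side sketch of the padding rule in the line above, assuming the usual `CEILDIV(a, b) = (a + b - 1) / b` definition: each expert's token count is rounded up to a multiple of `block_size` before being accumulated, so every expert's segment of the sorted token array starts at a block-aligned offset.

```cuda
// Host-side sketch of the block-padded prefix sum (assumed CEILDIV definition;
// the real kernel computes this on the GPU).
#include <cstdint>
#include <cstdio>

#define CEILDIV(a, b) (((a) + (b) - 1) / (b))

int main() {
  const int32_t block_size = 4;
  const int32_t counts[] = {5, 0, 3, 9};  // tokens routed to each expert
  int32_t cumsum = 0;
  for (int e = 0; e < 4; ++e) {
    // Round each expert's count up to a multiple of block_size, then accumulate.
    cumsum += CEILDIV(counts[e], block_size) * block_size;
    std::printf("expert %d: padded end offset = %d\n", e, cumsum);
  }
  // Padded counts with block_size = 4: 5 -> 8, 0 -> 0, 3 -> 4, 9 -> 12, total = 24.
  return 0;
}
```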
@@ -108,16 +108,18 @@ void moe_align_block_size(
     torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
+  TORCH_CHECK(num_experts % WARP_SIZE == 0);
+  int experts_per_warp = num_experts / WARP_SIZE;
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-    align_kernel<<<1, 1024, 0, stream>>>(
+    size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
+    align_kernel<<<1, 1024, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
         sorted_token_ids.data_ptr<int32_t>(),
         experts_ids.data_ptr<int32_t>(),
         num_tokens_post_pad.data_ptr<int32_t>(),
         num_experts,
+        experts_per_warp,
         block_size,
         topk_ids.numel(),
         cumsum_buffer.data_ptr<int32_t>());
...
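
A minimal sketch of the launch-side change, with an assumed `demo_kernel`: the byte size of the `extern __shared__` buffer is computed on the host and passed as the third `<<<>>>` launch parameter. The factor 32 in the commit corresponds to the 32 warps of the 1024-thread block (1024 / WARP_SIZE), which also equals WARP_SIZE, so the buffer holds one row of `experts_per_warp` counters per warp.

```cuda
// Launch sketch with runtime-sized dynamic shared memory (assumed kernel name
// and sizes; the real code threads these values through DISPATCH_INTEGRAL_TYPES
// and the current CUDA stream).
#include <cstdint>
#include <cuda_runtime.h>

#define WARP_SIZE 32

__global__ void demo_kernel(int32_t num_experts, int32_t experts_per_warp) {
  extern __shared__ int32_t shared_counts[];  // sized by the launch below
  // ... use shared_counts[0 .. warps_per_block * experts_per_warp) ...
}

int main() {
  const int threads = 1024;                        // 32 warps per block
  const int32_t num_experts = 256;                 // must be a multiple of WARP_SIZE
  const int32_t experts_per_warp = num_experts / WARP_SIZE;
  // Third launch argument: byte size of the extern __shared__ buffer.
  size_t shared_mem_size = (threads / WARP_SIZE) * experts_per_warp * sizeof(int32_t);
  demo_kernel<<<1, threads, shared_mem_size>>>(num_experts, experts_per_warp);
  cudaDeviceSynchronize();
  return 0;
}
```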
@@ -138,18 +138,20 @@ def moe_align_block_size_triton(
 @pytest.mark.parametrize(
-    "block_size,num_tokens,topk",
+    "block_size,num_tokens,topk,num_experts",
     list(
         itertools.product(
             [32, 64, 128, 256],  # block_size
             [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],  # num_tokens
             [1, 2, 4, 8, 16, 32, 64],  # topk
+            [64, 160, 256],  # num_experts
         )
     ),
 )
-def test_moe_align_block_size_compare_implementations(block_size, num_tokens, topk):
+def test_moe_align_block_size_compare_implementations(
+    block_size, num_tokens, topk, num_experts
+):
     # For DeepSeek V3, we have 256 experts
-    num_experts = 256
     topk_ids = torch.stack(
         [
...