Unverified commit 3bbb2046, authored by Xin Yang, committed by GitHub
Browse files

[Bugfix] Fix expert_ids padding values in moe_align_block_size kernel (#35161)


Signed-off-by: Xin Yang <xyangx@amazon.com>
parent 576fe503
...@@ -172,7 +172,7 @@ __device__ void _moe_align_block_size( ...@@ -172,7 +172,7 @@ __device__ void _moe_align_block_size(
} }
} }
// Fill remaining expert_ids with 0 // Fill remaining expert_ids with -1
const size_t fill_start_idx = const size_t fill_start_idx =
cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
...@@ -265,7 +265,7 @@ __device__ void _moe_align_block_size_small_batch_expert( ...@@ -265,7 +265,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
} }
} }
// Fill remaining expert_ids with 0 // Fill remaining expert_ids with -1
const size_t fill_start_idx = cumsum[num_experts] / block_size + tid; const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) { for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {
expert_ids[expert_ids_offset + i] = inactive_expert_id; expert_ids[expert_ids_offset + i] = inactive_expert_id;
...@@ -332,7 +332,7 @@ __global__ void moe_align_block_size_kernel( ...@@ -332,7 +332,7 @@ __global__ void moe_align_block_size_kernel(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, padded_num_experts, experts_per_warp, block_size, numel, num_experts, padded_num_experts, experts_per_warp, block_size, numel,
cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size), cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
0, 0, topk_num, nullptr, has_expert_map); 0, -1, topk_num, nullptr, has_expert_map);
} }
template <typename scalar_t> template <typename scalar_t>
...@@ -373,7 +373,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( ...@@ -373,7 +373,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>( _moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, block_size, numel, max_num_tokens_padded, num_experts, block_size, numel, max_num_tokens_padded,
CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr, CEILDIV(max_num_tokens_padded, block_size), -1, 0, topk_num, nullptr,
has_expert_map); has_expert_map);
} }
......
...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size, batched_moe_align_block_size,
moe_align_block_size, moe_align_block_size,
) )
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
NUM_TOKENS = [1, 3, 256, 2256, 4096] NUM_TOKENS = [1, 3, 256, 2256, 4096]
...@@ -142,7 +142,9 @@ def torch_moe_align_block_size( ...@@ -142,7 +142,9 @@ def torch_moe_align_block_size(
device=topk_ids.device, device=topk_ids.device,
) )
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device) expert_ids = torch.full(
(max_num_blocks,), -1, dtype=torch.int32, device=topk_ids.device
)
current_pos = 0 current_pos = 0
current_block = 0 current_block = 0
...@@ -234,9 +236,10 @@ def test_moe_align_block_size( ...@@ -234,9 +236,10 @@ def test_moe_align_block_size(
assert len(valid_tokens) == total_tokens, ( assert len(valid_tokens) == total_tokens, (
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}" f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
) )
assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), ( actual_num_blocks = cdiv(int(actual_num_tokens.item()), block_size)
"expert_ids should contain valid expert indices" assert (actual_expert_ids[:actual_num_blocks] >= 0).all() and (
) actual_expert_ids[:actual_num_blocks] < num_experts
).all(), "expert_ids should contain valid expert indices"
@pytest.mark.parametrize("m", [16, 32, 2048]) @pytest.mark.parametrize("m", [16, 32, 2048])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment