Unverified Commit 847a57cd authored by Wenlong Wang's avatar Wenlong Wang Committed by GitHub
Browse files

[Bugfix][MoE Kernel] Fix incorrect routing selection for models without expert...


[Bugfix][MoE Kernel] Fix incorrect routing selection for models without expert groups (e.g., MiniMax-M2.1) (#34673)
Signed-off-by: default avatarwwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent fcd6ac97
...@@ -398,80 +398,3 @@ def test_convert_moe_weights_to_flashinfer_trtllm_block_layout( ...@@ -398,80 +398,3 @@ def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
assert w13_converted.shape[0] == num_experts assert w13_converted.shape[0] == num_experts
assert w2_converted.shape[0] == num_experts assert w2_converted.shape[0] == num_experts
def test_flashinfer_blockscale_fp8_none_expert_group(monkeypatch):
"""Test that flashinfer_fused_moe_blockscale_fp8 handles num_expert_group=None.
Regression test for https://github.com/vllm-project/vllm/issues/34477
MiniMax-M2.1 uses sigmoid scoring with e_score_correction_bias but no
grouped top-k, resulting in num_expert_group=None. This triggered a crash
in the flashinfer kernel when DeepSeekV3 routing was selected.
"""
if not current_platform.has_device_capability(100):
pytest.skip("Test requires SM >= 100 (Blackwell)")
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from tests.kernels.quant_utils import native_per_token_group_quant_fp8
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
e = 16 # num_experts (must be divisible by 4)
topk = 6 # top_k > 1 triggers DeepSeekV3 routing with sigmoid
m, n, k = 10, 4096, 5120
block_shape = [128, 128]
block_k = block_shape[1]
with set_current_vllm_config(vllm_config):
# Create BF16 hidden states
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
# Create FP8 block-scale quantized weights
w13_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
# Quantize weights per-block to FP8
w13_fp8_list, w13_scale_list = [], []
w2_fp8_list, w2_scale_list = [], []
for i in range(e):
wq, ws = native_per_token_group_quant_fp8(w13_bf16[i], block_k)
w13_fp8_list.append(wq)
w13_scale_list.append(ws)
wq, ws = native_per_token_group_quant_fp8(w2_bf16[i], block_k)
w2_fp8_list.append(wq)
w2_scale_list.append(ws)
w13_fp8 = torch.stack(w13_fp8_list)
w13_scale = torch.stack(w13_scale_list)
w2_fp8 = torch.stack(w2_fp8_list)
w2_scale = torch.stack(w2_scale_list)
# DeepSeekV3 routing uses float32 logits + optional bias
routing_logits = torch.randn((m, e), device="cuda", dtype=torch.float32)
routing_bias = torch.randn(e, device="cuda", dtype=torch.float32)
# This should NOT crash with num_expert_group=None
output = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=routing_logits,
routing_bias=routing_bias,
x=x,
w13_weight=w13_fp8,
w13_weight_scale_inv=w13_scale,
w2_weight=w2_fp8,
w2_weight_scale_inv=w2_scale,
global_num_experts=e,
top_k=topk,
num_expert_group=None,
topk_group=None,
intermediate_size=n,
expert_offset=0,
local_num_experts=e,
block_shape=block_shape,
routing_method_type=RoutingMethodType.DeepSeekV3,
routed_scaling=1.0,
)
assert output is not None
assert output.shape == (m, k)
...@@ -8,11 +8,7 @@ import torch ...@@ -8,11 +8,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed import ( from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_rank,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
...@@ -126,20 +122,31 @@ class RoutingMethodType(IntEnum): ...@@ -126,20 +122,31 @@ class RoutingMethodType(IntEnum):
def get_routing_method_type( def get_routing_method_type(
scoring_func: str, top_k: int, renormalize: bool scoring_func: str,
top_k: int,
renormalize: bool,
num_expert_group: int | None,
has_e_score_bias: bool,
) -> RoutingMethodType: ) -> RoutingMethodType:
if has_e_score_bias:
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
return RoutingMethodType.DeepSeekV3
else:
return RoutingMethodType.Unspecified
if scoring_func == "sigmoid": if scoring_func == "sigmoid":
if top_k == 1: if top_k == 1:
return RoutingMethodType.Llama4 return RoutingMethodType.Llama4
else: else:
return RoutingMethodType.DeepSeekV3 return RoutingMethodType.Unspecified
elif scoring_func == "softmax":
if scoring_func == "softmax":
if renormalize: if renormalize:
return RoutingMethodType.Renormalize return RoutingMethodType.Renormalize
else: else:
return RoutingMethodType.Default return RoutingMethodType.Default
else:
return RoutingMethodType.Unspecified return RoutingMethodType.Unspecified
@dataclass @dataclass
......
...@@ -165,6 +165,8 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -165,6 +165,8 @@ class FusedTopKBiasRouter(BaseRouter):
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
top_k=self.top_k, top_k=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=True,
) )
def _compute_routing( def _compute_routing(
......
...@@ -142,6 +142,8 @@ class FusedTopKRouter(BaseRouter): ...@@ -142,6 +142,8 @@ class FusedTopKRouter(BaseRouter):
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
top_k=self.top_k, top_k=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=False,
) )
def _compute_routing( def _compute_routing(
......
...@@ -13,7 +13,10 @@ from vllm.model_executor.custom_op import CustomOp ...@@ -13,7 +13,10 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_grouped_topk, rocm_aiter_grouped_topk,
) )
...@@ -277,16 +280,15 @@ class GroupedTopKRouter(BaseRouter): ...@@ -277,16 +280,15 @@ class GroupedTopKRouter(BaseRouter):
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts self.num_fused_shared_experts = num_fused_shared_experts
if scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3
else:
# NOTE: this prohibits the FLASHINFER_TRTLLM kernels from
# being selected, since they only support DeepSeek-style.
self._routing_method_type = RoutingMethodType.Unspecified
@property @property
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
return self._routing_method_type return get_routing_method_type(
scoring_func=self.scoring_func,
top_k=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
has_e_score_bias=self.e_score_correction_bias is not None,
)
def _compute_routing( def _compute_routing(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment