Unverified commit 18da2c96, authored by Kaixi Hou, committed by GitHub

[NVIDIA] Fix trtllm fp4 moe backend when used in MTP (#9384)

parent 9b5f0f64
@@ -783,13 +783,17 @@ class DeepEPMoE(EPMoE):
         return hidden_states
-def get_moe_impl_class():
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
     # NEW: Direct FP4 detection (bypasses EP requirements)
     # Check for FP4 quantization with TRTLLM flag, regardless of EP
     if get_moe_runner_backend().is_flashinfer_trtllm():
+        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+        if quant_config is None:
+            return FusedMoE
         try:
             # Check the quantization argument directly
             quantization = global_server_args_dict.get("quantization")
...
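For readers following the fix, here is a minimal, hedged sketch of the dispatch rule this hunk introduces. The stub classes and the `trtllm_backend` flag below are illustrative stand-ins rather than the real sglang objects; the point is only that the same MoE runner backend setting now yields different implementations depending on whether a quantization config exists.

from typing import Optional


class FusedMoE:  # stand-in for sglang's generic FusedMoE layer
    pass


class FlashInferFP4MoE(FusedMoE):  # stand-in for the FP4 TRTLLM implementation
    pass


def pick_moe_impl(quant_config: Optional[object], trtllm_backend: bool) -> type:
    """Illustrative restatement of the new selection rule: the FP4-specific
    class is only chosen when a quantization config is actually present."""
    if trtllm_backend:
        if quant_config is None:
            # An unquantized MoE layer cannot satisfy the
            # ModelOptNvFp4FusedMoEMethod pairing, so fall back.
            return FusedMoE
        return FlashInferFP4MoE
    return FusedMoE


# Same backend setting, different outcome depending on quant_config.
assert pick_moe_impl(None, trtllm_backend=True) is FusedMoE
assert pick_moe_impl(object(), trtllm_backend=True) is FlashInferFP4MoE

In the real function the quantized branch still goes through the try/except shown above before settling on an implementation; the sketch collapses that check.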
@@ -1008,6 +1008,8 @@ class FlashInferFP4MoE(FusedMoE):
             hidden_states: Input tensor
             topk_output: TopKOutput object with Bypassed format
         """
+        assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
         assert TopKOutputChecker.format_is_bypassed(topk_output)
         router_logits = topk_output.router_logits
...
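The added assertion makes the pairing requirement explicit at the kernel boundary. A hedged sketch of the same fail-fast pattern, with placeholder class names standing in for the real quant-method classes:

class ModelOptNvFp4FusedMoEMethod:  # placeholder for the real quant-method class
    pass


class UnquantizedFusedMoEMethod:  # placeholder
    pass


class Fp4MoELayerSketch:
    """Refuse to enter the FP4 kernel path unless the layer was built with the
    matching quantization method, instead of failing later inside the kernel."""

    def __init__(self, quant_method):
        self.quant_method = quant_method

    def forward(self, hidden_states):
        assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod), (
            "FP4 TRTLLM path requires ModelOptNvFp4FusedMoEMethod, "
            f"got {type(self.quant_method).__name__}"
        )
        return hidden_states  # the real layer would invoke the fused FP4 kernel here


Fp4MoELayerSketch(ModelOptNvFp4FusedMoEMethod()).forward("x")    # ok
try:
    Fp4MoELayerSketch(UnquantizedFusedMoEMethod()).forward("x")  # rejected early
except AssertionError as err:
    print(err)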
@@ -198,6 +198,7 @@ class TopK(CustomOp):
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
+        force_topk: bool = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -220,6 +221,7 @@ class TopK(CustomOp):
         )
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.force_topk = force_topk
     def forward_native(
         self,
@@ -254,7 +256,7 @@ class TopK(CustomOp):
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif (
+        elif not self.force_topk and (
             should_use_flashinfer_trtllm_moe()
             or get_moe_runner_backend().is_flashinfer_mxfp4()
         ):
...
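A hedged sketch of what `force_topk` changes in the dispatch (the helper function and the string tags below are illustrative, not sglang's actual TopKOutput types): when the flag is set, the operator computes an ordinary top-k selection instead of returning the bypassed format that defers routing to the flashinfer TRTLLM kernel.

import torch


def select_experts_sketch(
    router_logits: torch.Tensor,
    top_k: int,
    force_topk: bool,
    trtllm_moe_enabled: bool,
):
    """Simplified stand-in for the TopK forward dispatch with the new flag."""
    if not force_topk and trtllm_moe_enabled:
        # Bypassed format: hand the raw logits to the fused kernel, which
        # performs routing internally (only valid for the FP4 TRTLLM path).
        return ("bypassed", router_logits)
    # Standard format: compute weights/ids here so any FusedMoE impl can use them.
    topk_weights, topk_ids = torch.topk(
        torch.softmax(router_logits, dim=-1), top_k, dim=-1
    )
    return ("standard", topk_weights, topk_ids)


logits = torch.randn(4, 8)
assert select_experts_sketch(logits, 2, force_topk=False, trtllm_moe_enabled=True)[0] == "bypassed"
assert select_experts_sketch(logits, 2, force_topk=True, trtllm_moe_enabled=True)[0] == "standard"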
@@ -319,7 +319,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -343,6 +343,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            force_topk=quant_config is None,
         )
         self.shared_experts_is_int8 = False
...
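Taken together, the caller keeps the experts implementation and the top-k output format in sync. A minimal illustration of that invariant, simplified to assume the TRTLLM FP4 backend is otherwise selected and that an unquantized layer (quant_config is None, e.g. the MTP/nextn draft path this commit targets) must land on the generic FusedMoE + standard top-k route:

from typing import Optional


def build_moe_layer_sketch(quant_config: Optional[object]):
    """Sketch of the invariant the caller now maintains: the experts class and
    the TopK output format are decided by the same condition, so an unquantized
    layer never receives the bypassed format that only FlashInferFP4MoE accepts."""
    force_topk = quant_config is None                               # mirrors force_topk=quant_config is None
    experts_cls = "FusedMoE" if force_topk else "FlashInferFP4MoE"  # mirrors get_moe_impl_class(quant_config)
    topk_format = "standard" if force_topk else "bypassed"
    return experts_cls, topk_format


# FP4-quantized main-model layer vs. an unquantized draft (MTP) layer.
assert build_moe_layer_sketch(object()) == ("FlashInferFP4MoE", "bypassed")
assert build_moe_layer_sketch(None) == ("FusedMoE", "standard")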