Unverified Commit 18da2c96 authored by Kaixi Hou, committed by GitHub

[NVIDIA] Fix trtllm fp4 moe backend when used in MTP (#9384)

parent 9b5f0f64
@@ -783,13 +783,17 @@ class DeepEPMoE(EPMoE):
         return hidden_states


-def get_moe_impl_class():
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE

     # NEW: Direct FP4 detection (bypasses EP requirements)
     # Check for FP4 quantization with TRTLLM flag, regardless of EP
     if get_moe_runner_backend().is_flashinfer_trtllm():
+        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+        if quant_config is None:
+            return FusedMoE
         try:
             # Check the quantization argument directly
             quantization = global_server_args_dict.get("quantization")
...
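For readability, a minimal sketch of how the backend selection reads after this hunk. All names are taken from the hunk above and assume the module's existing imports; the FP4 detection inside the try block and the remaining backend checks are elided here exactly as they are in the diff.

    # Sketch only; elided parts mirror the truncation in the hunk above.
    def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
        if get_moe_a2a_backend().is_deepep():
            return DeepEPMoE

        if get_moe_runner_backend().is_flashinfer_trtllm():
            # FlashInferFP4MoE only works with ModelOptNvFp4FusedMoEMethod,
            # so an unquantized module (quant_config is None) falls back to
            # the plain FusedMoE implementation.
            if quant_config is None:
                return FusedMoE
            # FP4 detection via global_server_args_dict["quantization"]
            # continues here (elided in the diff).
            ...
        # Remaining backend checks are not shown in this hunk.
        ...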
@@ -1008,6 +1008,8 @@ class FlashInferFP4MoE(FusedMoE):
             hidden_states: Input tensor
             topk_output: TopKOutput object with Bypassed format
         """
+        assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+
         assert TopKOutputChecker.format_is_bypassed(topk_output)

         router_logits = topk_output.router_logits
...
@@ -198,6 +198,7 @@ class TopK(CustomOp):
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
+        force_topk: bool = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -220,6 +221,7 @@ class TopK(CustomOp):
         )

         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.force_topk = force_topk

     def forward_native(
         self,
@@ -254,7 +256,7 @@ class TopK(CustomOp):
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif (
+        elif not self.force_topk and (
             should_use_flashinfer_trtllm_moe()
             or get_moe_runner_backend().is_flashinfer_mxfp4()
         ):
...
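A minimal sketch of the dispatch around the changed condition. Only the condition itself comes from the hunk above; the enclosing forward method is not shown in the diff, the triton-kernel guard is inferred from the self.use_triton_kernels attribute set in the constructor, and the remaining top-k branches are collapsed into a single else for brevity.

    # Sketch only; branch bodies elided.
    if self.use_triton_kernels:
        ...
        return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
    elif not self.force_topk and (
        should_use_flashinfer_trtllm_moe()
        or get_moe_runner_backend().is_flashinfer_mxfp4()
    ):
        # Bypassed routing: raw router_logits are handed to the fused kernel.
        ...
    else:
        # Standard top-k selection; force_topk=True always lands here even
        # when the FlashInfer TRT-LLM or MXFP4 backend is enabled.
        ...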
@@ -319,7 +319,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -343,6 +343,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            force_topk=quant_config is None,
         )

         self.shared_experts_is_int8 = False
...
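Taken together, the two DeepseekV2MoE changes keep the expert implementation and the router output format consistent: when quant_config is None (presumably the unquantized NextN/MTP draft module this commit targets), get_moe_impl_class(quant_config) now resolves to FusedMoE instead of FlashInferFP4MoE, and force_topk=quant_config is None steers the TopK layer onto the standard top-k path so it no longer produces the bypassed format that only the FP4 kernel consumes.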