Unverified Commit 58c468f4 authored by Trevor Morris, committed by GitHub

Fix FP4 MoE accuracy from missing routed_scaling_factor (#8333)

parent f8ca2368
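
The change multiplies the fused FP4 MoE kernel output by the model's routed_scaling_factor, when one is set, in both the FlashInfer (trtllm_fp4_block_scale_moe) and CUTLASS (cutlass_moe_fp4) code paths. A minimal sketch of the pattern being added; the helper name apply_routed_scaling is hypothetical and not part of the sglang API:

    import torch
    from typing import Optional

    def apply_routed_scaling(moe_output: torch.Tensor,
                             routed_scaling_factor: Optional[float]) -> torch.Tensor:
        # DeepSeek-style MoE models scale the combined routed-expert output by a
        # per-model routed_scaling_factor. If the fused FP4 kernel does not apply
        # it internally, it has to be applied to the kernel output afterwards,
        # otherwise the routed contribution is mis-scaled and accuracy drops.
        if routed_scaling_factor is not None:
            moe_output = moe_output * routed_scaling_factor
        return moe_output
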
@@ -952,7 +952,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
         if self.enable_flashinfer_moe:
@@ -982,13 +981,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_size=tp_size,
                 tp_rank=tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
-            )
-            return output[0]
+            )[0]
+            if routed_scaling_factor is not None:
+                output *= routed_scaling_factor
+            return output
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
         topk_weights, topk_ids, _ = topk_output
-        return cutlass_moe_fp4(
+        output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
             w1_fp4=layer.w13_weight,
@@ -1003,3 +1004,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             params=layer.cutlass_moe_params,
             apply_router_weight_on_input=apply_router_weight_on_input,
         ).to(x.dtype)
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
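
For reference, scaling the combined kernel output is equivalent to scaling every top-k routing weight, because the MoE output is a weighted sum of expert outputs. An illustrative check with toy tensors (not sglang code, shapes chosen arbitrarily):

    import torch

    torch.manual_seed(0)
    expert_out = torch.randn(4, 2, 8)   # [num_tokens, top_k, hidden_size]
    topk_weights = torch.rand(4, 2)     # per-token routing weights
    s = 2.5                             # routed_scaling_factor

    # Weighted sum of expert outputs, then scale the combined result ...
    scaled_after = (topk_weights.unsqueeze(-1) * expert_out).sum(dim=1) * s
    # ... versus scaling the routing weights before the sum.
    scaled_weights = ((topk_weights * s).unsqueeze(-1) * expert_out).sum(dim=1)

    assert torch.allclose(scaled_after, scaled_weights, atol=1e-6)
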
@@ -433,10 +433,6 @@ class ServerArgs:
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"
-            self.disable_shared_experts_fusion = True
-            logger.warning(
-                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
-            )
         # DeepEP MoE
         if self.enable_deepep_moe: