"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "948dd3443bc6b8ffb76cbdddf3f4c5ae0b6637fa"
Unverified Commit 92fbec39 authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Bug] Fix routing bias dtype for trtllm per-block fp8 moe (#38989)


Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 2f41d6c0
...@@ -358,6 +358,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ...@@ -358,6 +358,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
if self.routing_method_type == RoutingMethodType.DeepSeekV3: if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32) router_logits = router_logits.to(torch.float32)
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
is_mxfp8 = self.quant_config.block_shape == [1, 32] is_mxfp8 = self.quant_config.block_shape == [1, 32]
if is_mxfp8: if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8 fp8_quant_type = Fp8QuantizationType.MxFp8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment