Commit 081b5594 (unverified), authored by Shu Wang, committed by GitHub
Browse files

Fix routing_bias dtype (#25711)


Signed-off-by: Shu Wang. <shuw@nvidia.com>
parent 57329a8c
...@@ -1454,10 +1454,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1454,10 +1454,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3 routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
if use_llama4_routing: if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4 routing_method_type = flashinfer.RoutingMethodType.Llama4
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits routing_logits=router_logits
if use_llama4_routing else router_logits.to(torch.float32), if use_llama4_routing else router_logits.to(torch.float32),
routing_bias=e_score_correction_bias, routing_bias=routing_bias,
hidden_states=hidden_states_fp4, hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view( hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn).flatten(), torch.float8_e4m3fn).flatten(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment