"vllm/vscode:/vscode.git/clone" did not exist on "6682c231fa97f33d3b3f4d788da4e14959989a67"
Unverified commit 5bf3c42d, authored by Robert Shaw, committed by GitHub
Browse files

[Bug][MoE] Fix TRTLLM NVFP4 Routing Kernel Precision (#36725)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent 38364a7e
...@@ -298,10 +298,7 @@ class TrtLlmNvFp4ExpertsMonolithic( ...@@ -298,10 +298,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
and self.routing_method_type != RoutingMethodType.Llama4 and self.routing_method_type != RoutingMethodType.Llama4
) )
# Prepare routing bias into kernel format. # Prepare router logits for kernel format.
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
router_logits = ( router_logits = (
router_logits.to(torch.float32) router_logits.to(torch.float32)
if self.routing_method_type == RoutingMethodType.DeepSeekV3 if self.routing_method_type == RoutingMethodType.DeepSeekV3
...@@ -311,7 +308,7 @@ class TrtLlmNvFp4ExpertsMonolithic( ...@@ -311,7 +308,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
# Invoke kernel. # Invoke kernel.
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe( return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits, routing_logits=router_logits,
routing_bias=routing_bias, routing_bias=e_score_correction_bias,
hidden_states=hidden_states, hidden_states=hidden_states,
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape( hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1 *hidden_states.shape[:-1], -1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment