Unverified Commit 9b17c574 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[ModelBash][DSR1 NVFp4] Removed Bf16 Bias Cast (#34298)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent 1b3540e6
...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend, NvFp4MoeBackend,
) )
...@@ -316,11 +317,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -316,11 +317,7 @@ def flashinfer_trtllm_fp4_moe(
if use_llama4_routing: if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4 routing_method_type = flashinfer.RoutingMethodType.Llama4
# Prepare routing bias # Cast to Fp32 (required by kernel).
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
router_logits = ( router_logits = (
router_logits.to(torch.float32) router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3 if routing_method_type == RoutingMethodType.DeepSeekV3
...@@ -330,7 +327,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -330,7 +327,7 @@ def flashinfer_trtllm_fp4_moe(
# Call TRT-LLM FP4 block-scale MoE kernel # Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits, routing_logits=router_logits,
routing_bias=routing_bias, routing_bias=e_score_correction_bias,
hidden_states=hidden_states_fp4, hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view( hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn torch.float8_e4m3fn
...@@ -447,7 +444,7 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -447,7 +444,7 @@ def flashinfer_trtllm_fp4_routed_moe(
def prepare_nvfp4_moe_layer_for_fi_or_cutlass( def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend: "NvFp4MoeBackend", backend: "NvFp4MoeBackend",
layer: torch.nn.Module, layer: "FusedMoE",
w13: torch.Tensor, w13: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
w13_scale_2: torch.Tensor, w13_scale_2: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment