Unverified Commit 561253b3 authored by jiahanc's avatar jiahanc Committed by GitHub
Browse files

[Performance][Fix] update nvfp4 code to support renorm routing (#28569)


Signed-off-by: default avatarjiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 80b6080d
......@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
......@@ -1657,16 +1658,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
use_llama4_routing = (
custom_routing_function is Llama4MoE.custom_routing_function
)
routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
routing_method_type = RoutingMethodType.Llama4
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits
if use_llama4_routing
else router_logits.to(torch.float32),
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
......@@ -1690,8 +1694,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group if num_expert_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
......
......@@ -291,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
# TODO(shuw@nvidia): Update when new backends are added.
backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
backends_supporting_global_sf = (
FlashinferMoeBackend.CUTLASS,
FlashinferMoeBackend.TENSORRT_LLM,
)
return backend in backends_supporting_global_sf
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment