Unverified Commit 73503317 authored by jiahanc's avatar jiahanc Committed by GitHub
Browse files

[BugFix] Fix TRT-LLM NVFP4 DP/EP (#32349)


Signed-off-by: default avatarjiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent 9d1e611f
model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "latency"
Qwen3-30B-A3B-NvFp4-CT-fi-cutedsl-deepep-ll.yaml Qwen3-30B-A3B-NvFp4-CT-fi-cutedsl-deepep-ll.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutedsl-deepep-ll.yaml Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutedsl-deepep-ll.yaml
Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
......
...@@ -53,7 +53,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -53,7 +53,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
aux_stream, aux_stream,
...@@ -1761,17 +1760,11 @@ class FusedMoE(CustomOp): ...@@ -1761,17 +1760,11 @@ class FusedMoE(CustomOp):
with sp_ctx: with sp_ctx:
extra_tensors = None extra_tensors = None
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4FusedMoE,
)
post_quant_allgather = ( post_quant_allgather = (
self.quant_method is not None self.quant_method is not None
and self.dp_size > 1 and self.dp_size > 1
and self.use_ep and self.use_ep
and isinstance(self.quant_method, ModelOptNvFp4FusedMoE) and getattr(self.quant_method, "do_post_quant_allgather", False)
and has_flashinfer_trtllm_fused_moe()
) )
if post_quant_allgather: if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = ( hidden_states_to_dispatch, extra_tensors = (
......
...@@ -1564,6 +1564,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1564,6 +1564,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
moe_config=self.moe, moe_config=self.moe,
) )
@property
def do_post_quant_allgather(self):
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
def prepare_dp_allgather_tensor( def prepare_dp_allgather_tensor(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -1571,13 +1575,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1571,13 +1575,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]: ) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Optionally prepare extra tensors to carry through DP allgather/EP.""" """Optionally prepare extra tensors to carry through DP allgather/EP."""
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise RuntimeError(
"prepare_dp_allgather_tensor is only supported for "
"FlashInfer TRTLLM NVFP4 MoE backend."
)
import flashinfer import flashinfer
assert self.moe_quant_config is not None
a1_gscale = self.moe_quant_config.a1_gscale
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states, hidden_states,
a1_gscale, layer.a1_gscale,
is_sf_swizzled_layout=False, is_sf_swizzled_layout=False,
) )
extra_tensors: list[torch.Tensor] = [hidden_states_sf] extra_tensors: list[torch.Tensor] = [hidden_states_sf]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment