"docs/benchmarking/cli.md" did not exist on "67187554dd478ba76e79d7a6f8bf02be01290de3"
Unverified commit 23d825ab, authored by Luka Govedič and committed by GitHub
Browse files

[torch.compile] Disable ar-rms fusion for ds3-fp4 & DP, fix CI test (#34392)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent f07a1284
...@@ -1687,6 +1687,20 @@ class ModelConfig: ...@@ -1687,6 +1687,20 @@ class ModelConfig:
def is_quantized(self) -> bool: def is_quantized(self) -> bool:
return getattr(self.hf_config, "quantization_config", None) is not None return getattr(self.hf_config, "quantization_config", None) is not None
def is_nvfp4_quantized(self) -> bool:
    """Return whether this model is configured with NVFP4 quantization.

    Two cases are recognised:

    * ``self.quantization`` is ``"modelopt_fp4"``.
    * ``self.quantization`` is ``"compressed-tensors"`` and the
      quantization config has a string ``"format"`` entry containing
      ``"nvfp4"`` (case-insensitive), e.g. ``"nvfp4-pack-quantized"``.

    Returns:
        ``True`` if either case applies, ``False`` otherwise. A missing,
        null, or non-string ``"format"`` entry counts as "not NVFP4".
    """
    # ModelOpt NVFP4 checkpoints resolve to modelopt_fp4 quantization method
    if self.quantization == "modelopt_fp4":
        return True

    # For Compressed Tensors we look for `"format": "nvfp4-pack-quantized"`
    # in the quantization config
    quant_config = self.model_arch_config.quantization_config
    if self.quantization != "compressed-tensors" or quant_config is None:
        return False

    # `.get("format", "")` only guards against a missing key; a config that
    # spells out `"format": null` would otherwise crash on `.lower()`.
    fmt = quant_config.get("format")
    return isinstance(fmt, str) and "nvfp4" in fmt.lower()
def get_served_model_name(model: str, served_model_name: str | list[str] | None): def get_served_model_name(model: str, served_model_name: str | list[str] | None):
""" """
......
...@@ -103,15 +103,21 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool: ...@@ -103,15 +103,21 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
"""Enable if TP > 1 and Hopper+ and flashinfer installed.""" """Enable if TP > 1 and Hopper/Blackwell and flashinfer installed."""
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
return ( return (
cfg.parallel_config.tensor_parallel_size > 1 cfg.parallel_config.tensor_parallel_size > 1
and current_platform.is_cuda() and current_platform.is_cuda()
and current_platform.has_device_capability(90)
and has_flashinfer() and has_flashinfer()
and (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
)
# tp-dp combination broken:
# https://github.com/vllm-project/vllm/issues/34458
and cfg.parallel_config.data_parallel_size == 1
) )
......
...@@ -536,12 +536,34 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ...@@ -536,12 +536,34 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
) )
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig): class DeepseekV3ForCausalLM(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
    """Disable AR-RMS-Quant fusion for DeepSeekV3 in NVFP4.

    When the model config reports NVFP4 quantization:

    * if ``pass_config.fuse_allreduce_rms`` is unset (``None``), it is
      set to ``False``;
    * if it is truthy, it is left as-is and a warning is logged.

    In every other case the config is left untouched.

    Args:
        vllm_config: The config to inspect and, if needed, update in place.
    """
    # TODO: https://github.com/vllm-project/vllm/issues/34395
    # disable AR-rms-fp4 fusion for DSv3+
    pass_config = vllm_config.compilation_config.pass_config
    ar_rms_enabled = pass_config.fuse_allreduce_rms
    if not vllm_config.model_config.is_nvfp4_quantized():
        return

    # Disable by default, warn if manually enabled:
    if ar_rms_enabled is None:
        pass_config.fuse_allreduce_rms = False
    elif ar_rms_enabled:
        # The trailing space matters: adjacent literals are concatenated
        # verbatim, and without it the message reads "quant,see https://".
        logger.warning(
            "Allreduce-rms fusion broken for DeepSeekV3 with NVFP4 quant, "
            "see https://github.com/vllm-project/vllm/issues/34395."
        )
class DeepseekV32ForCausalLM(DeepseekV3ForCausalLM):
@classmethod @classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
""" """
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32 Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
""" """
super().verify_and_update_config(vllm_config)
hf_config = vllm_config.model_config.hf_config hf_config = vllm_config.model_config.hf_config
# Mirror the check in vllm/model_executor/models/deepseek_v2.py # Mirror the check in vllm/model_executor/models/deepseek_v2.py
...@@ -632,6 +654,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { ...@@ -632,6 +654,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"MambaForCausalLM": MambaModelConfig, "MambaForCausalLM": MambaModelConfig,
"Mamba2ForCausalLM": MambaModelConfig, "Mamba2ForCausalLM": MambaModelConfig,
"FalconMambaForCausalLM": MambaModelConfig, "FalconMambaForCausalLM": MambaModelConfig,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM, "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"NemotronHForCausalLM": NemotronHForCausalLMConfig, "NemotronHForCausalLM": NemotronHForCausalLMConfig,
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig, "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment