Unverified Commit dbdd9ae0 authored by Hongxia Yang's avatar Hongxia Yang Committed by GitHub
Browse files

[ROCm][Bugfix] fix exception related to trust_remote_code for MiniMax-M2.1-MXFP4 (#37698)


Signed-off-by: default avatarHongxia Yang <hongxiay.yang@amd.com>
Co-authored-by: default avatarHongxia Yang <hongxiay.yang@amd.com>
parent e8b055a5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for QuarkConfig.maybe_update_config.
Fetches real HF configs (metadata only, no model weights) to verify
that dynamic_mxfp4_quant is only enabled for DeepSeek-V3-family models.
Run: pytest tests/quantization/test_quark_maybe_update_config.py -v
"""
import pytest
from transformers import AutoConfig
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
def _make_quark_config() -> QuarkConfig:
"""Create a minimal QuarkConfig for testing."""
return QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder")
# ---------------------------------------------------------------------------
# Non-deepseek models must not flip dynamic_mxfp4_quant
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"model_name",
["amd/MiniMax-M2.1-MXFP4"],
)
def test_non_deepseek_model_stays_false(model_name: str):
"""Non-deepseek_v3 models must not enable dynamic_mxfp4_quant."""
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
qcfg = _make_quark_config()
qcfg.maybe_update_config(model_name, hf_config=hf_config)
assert qcfg.dynamic_mxfp4_quant is False
# ---------------------------------------------------------------------------
# DeepSeek-V3 family + fp4 must enable dynamic_mxfp4_quant
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"model_name",
["amd/DeepSeek-R1-MXFP4-ASQ"],
)
def test_deepseek_family_fp4_enables_flag(model_name: str):
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
qcfg = _make_quark_config()
qcfg.maybe_update_config(model_name, hf_config=hf_config)
assert qcfg.dynamic_mxfp4_quant is True
# ---------------------------------------------------------------------------
# Missing hf_config → warn and stay False
# ---------------------------------------------------------------------------
def test_missing_hf_config_stays_false():
qcfg = _make_quark_config()
qcfg.maybe_update_config("some/model")
assert qcfg.dynamic_mxfp4_quant is False
...@@ -526,7 +526,10 @@ class VllmConfig: ...@@ -526,7 +526,10 @@ class VllmConfig:
f"method {model_config.quantization}. Supported dtypes: " f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}" f"{supported_dtypes}"
) )
quant_config.maybe_update_config(model_config.model) quant_config.maybe_update_config(
model_config.model,
hf_config=model_config.hf_config,
)
return quant_config return quant_config
return None return None
......
...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Union ...@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Union
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -146,7 +147,12 @@ class AWQConfig(QuantizationConfig): ...@@ -146,7 +147,12 @@ class AWQConfig(QuantizationConfig):
self.modules_to_not_convert self.modules_to_not_convert
) )
def maybe_update_config(self, model_name: str, revision: str | None = None): def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert: if self.modules_to_not_convert:
return return
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn import Parameter from torch.nn import Parameter
from transformers import PretrainedConfig
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -332,7 +333,12 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -332,7 +333,12 @@ class AWQMarlinConfig(QuantizationConfig):
self.modules_to_not_convert self.modules_to_not_convert
) )
def maybe_update_config(self, model_name: str, revision: str | None = None): def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert: if self.modules_to_not_convert:
return return
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -168,10 +169,23 @@ class QuantizationConfig(ABC): ...@@ -168,10 +169,23 @@ class QuantizationConfig(ABC):
# TODO (@kylesayrs): add implementations for all subclasses # TODO (@kylesayrs): add implementations for all subclasses
pass pass
def maybe_update_config(self, model_name: str): # noqa: B027 def maybe_update_config( # noqa: B027
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
""" """
Interface to update values after config initialization. Interface to update values after config initialization.
Args:
model_name: The name of the model
hf_config: The Hugging Face config of the model
revision: The revision of the model
Returns:
""" """
# TODO: revision is never passed currently in vllm.py,
# but is used in subclasses, should we remove this parameter?
pass pass
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
......
...@@ -5,6 +5,7 @@ from typing import Any ...@@ -5,6 +5,7 @@ from typing import Any
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
from vllm._custom_ops import ( from vllm._custom_ops import (
cpu_gemm_wna16, cpu_gemm_wna16,
...@@ -133,7 +134,12 @@ class CPUAWQConfig(QuantizationConfig): ...@@ -133,7 +134,12 @@ class CPUAWQConfig(QuantizationConfig):
self.modules_to_not_convert self.modules_to_not_convert
) )
def maybe_update_config(self, model_name: str, revision: str | None = None): def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert: if self.modules_to_not_convert:
return return
......
...@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union ...@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -193,7 +194,12 @@ class GPTQConfig(QuantizationConfig): ...@@ -193,7 +194,12 @@ class GPTQConfig(QuantizationConfig):
self.modules_in_block_to_quantize self.modules_in_block_to_quantize
) )
def maybe_update_config(self, model_name: str, revision: str | None = None): def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_in_block_to_quantize: if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list): if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]] # original modules_in_block_to_quantize: list[list[str]]
......
...@@ -6,6 +6,7 @@ from typing import Any ...@@ -6,6 +6,7 @@ from typing import Any
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -299,7 +300,12 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -299,7 +300,12 @@ class GPTQMarlinConfig(QuantizationConfig):
self.modules_in_block_to_quantize self.modules_in_block_to_quantize
) )
def maybe_update_config(self, model_name: str, revision: str | None = None): def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_in_block_to_quantize: if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list): if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]] # original modules_in_block_to_quantize: list[list[str]]
......
...@@ -5,6 +5,7 @@ import fnmatch ...@@ -5,6 +5,7 @@ import fnmatch
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
import torch import torch
from transformers import PretrainedConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
...@@ -36,7 +37,6 @@ from vllm.model_executor.layers.quantization.quark.utils import ( ...@@ -36,7 +37,6 @@ from vllm.model_executor.layers.quantization.quark.utils import (
) )
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
...@@ -45,6 +45,10 @@ __all__ = ["QuarkLinearMethod"] ...@@ -45,6 +45,10 @@ __all__ = ["QuarkLinearMethod"]
logger = init_logger(__name__) logger = init_logger(__name__)
# model_type values that use dynamic MXFP4 re-quantization for
# OCP MX fp4 Quark checkpoints
_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"})
class QuarkConfig(QuantizationConfig): class QuarkConfig(QuantizationConfig):
def __init__( def __init__(
...@@ -63,19 +67,28 @@ class QuarkConfig(QuantizationConfig): ...@@ -63,19 +67,28 @@ class QuarkConfig(QuantizationConfig):
self.pack_method = pack_method self.pack_method = pack_method
self.dynamic_mxfp4_quant = False self.dynamic_mxfp4_quant = False
def maybe_update_config(self, model_name: str, revision: str | None = None): def maybe_update_config(
self.hf_config = get_config( self,
model=model_name, model_name: str,
trust_remote_code=False, # or get from model_config if available hf_config: PretrainedConfig | None = None,
revision=revision, revision: str | None = None,
config_format="auto", ):
) """Enable dynamic MXFP4 only for DeepSeek-V3-family + fp4 Quark checkpoints."""
quant_config = getattr(self.hf_config, "quantization_config", None) if (
getattr(hf_config, "model_type", None)
not in _DEEPSEEK_V3_FAMILY_MODEL_TYPES
):
return
quant_config = getattr(hf_config, "quantization_config", None)
if quant_config is not None: if quant_config is not None:
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"] quant_dtype = (
model_type = self.hf_config.model_type quant_config.get("global_quant_config", {})
if quant_dtype == "fp4" and model_type == "deepseek_v3": .get("weight", {})
.get("dtype")
)
if quant_dtype == "fp4":
self.dynamic_mxfp4_quant = True self.dynamic_mxfp4_quant = True
def get_linear_method(self) -> "QuarkLinearMethod": def get_linear_method(self) -> "QuarkLinearMethod":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment