Unverified Commit 739e5945 authored by Yongye Zhu, committed by GitHub
Browse files

[Quantization] [Refactor] Create special "GptOssMxfp4MoeMethod" (#39604)


Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
parent 4d042ed8
......@@ -59,7 +59,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`GptOssMxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.GptOssMxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
## Fused Experts Kernels
......
......@@ -951,6 +951,7 @@ class ModelConfig:
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
"gpt_oss_mxfp4",
"cpu_awq",
"gguf",
]
......@@ -966,7 +967,7 @@ class ModelConfig:
for name in quantization_methods:
method = me_quant.get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization
quant_cfg, self.quantization, hf_config=self.hf_config
)
if quantization_override is not None:
# Raise error if the override is not custom (custom would
......
......@@ -1063,7 +1063,7 @@ class FusedMoE(CustomOp):
expert_id: int,
return_success: bool = False,
) -> bool | None:
if self.quant_config and self.quant_config.get_name() == "mxfp4":
if self.quant_config and self.quant_config.get_name() == "gpt_oss_mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
......
......@@ -194,7 +194,7 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
return None
def select_mxfp4_moe_backend(
def select_gpt_oss_mxfp4_moe_backend(
config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
"""
......@@ -400,7 +400,7 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
return hidden_size, intermediate_size
def convert_to_mxfp4_moe_kernel_format(
def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend: Mxfp4MoeBackend,
layer: torch.nn.Module,
w13_weight: torch.Tensor,
......@@ -426,7 +426,10 @@ def convert_to_mxfp4_moe_kernel_format(
sf_block_size = 32 # mxfp4 block size
if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
if mxfp4_backend in (
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
):
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_mxfp4_layer_for_marlin,
)
......
......@@ -30,6 +30,7 @@ QuantizationMethods = Literal[
"torchao",
"inc",
"mxfp4",
"gpt_oss_mxfp4",
"mxfp8",
"cpu_awq",
"online",
......@@ -133,7 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
ModelOptNvFp4Config,
)
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .mxfp4 import GptOssMxfp4Config, Mxfp4Config
from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig
from .torchao import TorchAOConfig
......@@ -160,6 +161,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": INCConfig,
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config,
"mxfp8": Mxfp8Config,
"cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig,
......
......@@ -232,7 +232,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None":
# Skip override to marlin kernels, as they are not
# batch invariant
......
......@@ -110,13 +110,22 @@ class QuantizationConfig(ABC):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls,
hf_quant_cfg: dict[str, Any],
user_quant: str | None,
hf_config: Any = None,
) -> QuantizationMethods | None:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
this method should only be overwritten by subclasses in exceptional
circumstances
circumstances.
Args:
hf_quant_cfg: The checkpoint's quantization config dict.
user_quant: The user-specified quantization method string.
hf_config: The HuggingFace model config object (e.g. for
model_type checks). May be None if not available.
"""
return None
......
......@@ -104,7 +104,7 @@ class CPUAWQConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None":
quant_method = hf_quant_cfg.get("quant_method", "").lower()
if current_platform.is_cpu() and (quant_method == "awq"):
......
......@@ -84,7 +84,7 @@ class GGUFConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg: dict[str, Any], user_quant: str | None
cls, hf_quant_cfg: dict[str, Any], user_quant: str | None, hf_config=None
) -> "QuantizationMethods | None":
# When user explicitly specifies --quantization gguf, override
# whatever quantization method is in the HF model config (e.g. fp8).
......
......@@ -214,7 +214,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
......
......@@ -453,7 +453,7 @@ class INCConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None":
"""Override the `auto-round` method to `inc`."""
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
......
......@@ -406,7 +406,7 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "FP8":
......@@ -1028,7 +1028,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and ("NVFP4" in algo or "FP4" in algo):
......@@ -1525,7 +1525,7 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and "MXFP8" in algo:
......@@ -2052,7 +2052,7 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "MIXED_PRECISION":
......
......@@ -130,7 +130,7 @@ class MoeWNA16Config(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16":
......
......@@ -19,11 +19,11 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format,
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
select_gpt_oss_mxfp4_moe_backend,
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -38,6 +38,12 @@ logger = init_logger(__name__)
class Mxfp4Config(QuantizationConfig):
"""Canonical base config for MXFP4 quantization.
Subclasses override get_name() and override_quantization_method() to
register themselves as the handler for a specific checkpoint format.
"""
def __init__(self, ignored_layers: list[str] | None = None):
super().__init__()
self.ignored_layers = ignored_layers
......@@ -62,6 +68,8 @@ class Mxfp4Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return []
# TODO (zyongye) This is only a temporary fallback.
# We should have `Mxfp4MoEMethod` after this migration is complete.
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
......@@ -79,7 +87,7 @@ class Mxfp4Config(QuantizationConfig):
)
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
return GptOssMxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
logger.debug_once(
"MXFP4 attention layer is not implemented. "
......@@ -93,13 +101,46 @@ class Mxfp4Config(QuantizationConfig):
return True
class Mxfp4MoEMethod(FusedMoEMethodBase):
class GptOssMxfp4Config(Mxfp4Config):
    """MXFP4 quantization config specialised for GPT-OSS checkpoints.

    The checkpoint's quantization dict is expected to carry
    ``"quant_method": "mxfp4"``. ``override_quantization_method()``
    translates that into the internal name ``"gpt_oss_mxfp4"``, which is
    also what ``get_name()`` reports.
    """

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the internal registry name of this config."""
        return "gpt_oss_mxfp4"

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant, hf_config=None
    ) -> QuantizationMethods | None:
        """Claim an MXFP4 checkpoint only when it belongs to a GPT-OSS model.

        Args:
            hf_quant_cfg: The checkpoint's quantization config; anything
                that is not a dict is rejected.
            user_quant: The user-specified quantization method. It is not
                consulted by this implementation.
            hf_config: Object whose ``model_type`` attribute identifies the
                model. ``None`` (or a missing attribute) means no override.

        Returns:
            ``"gpt_oss_mxfp4"`` when the override applies, else ``None``.
        """
        if not isinstance(hf_quant_cfg, dict):
            return None

        # Accept the raw checkpoint value as well as the already
        # normalized internal name.
        checkpoint_method = hf_quant_cfg.get("quant_method")
        if checkpoint_method not in ("mxfp4", "gpt_oss_mxfp4"):
            return None

        # The model type must positively identify GPT-OSS. With no
        # hf_config the lookup yields None and the override is declined,
        # so other mxfp4 checkpoints are never claimed by accident.
        is_gpt_oss = getattr(hf_config, "model_type", None) == "gpt_oss"
        return "gpt_oss_mxfp4" if is_gpt_oss else None
class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
"""MXFP4 MoE quantization method."""
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.weight_dtype = "mxfp4"
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
self.weight_dtype = "gpt_oss_mxfp4"
self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
......@@ -281,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# Convert weights to kernel format
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_to_mxfp4_moe_kernel_format(
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend,
layer=layer,
w13_weight=w13,
......
......@@ -30,11 +30,11 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_m
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format,
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend,
select_gpt_oss_mxfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
......@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.w2_precision_config = None
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)
elif self.ocp_mx_scheme.startswith("w_mxfp4"):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self.mxfp4_backend = Mxfp4MoeBackend.NONE
......@@ -1300,7 +1300,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
# Convert weights to kernel format
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_to_mxfp4_moe_kernel_format(
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend,
layer=layer,
w13_weight=w13,
......
......@@ -108,6 +108,23 @@ class Gemma4Config(VerifyAndUpdateConfig):
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
quant_config = getattr(model_config.hf_config, "quantization_config", None)
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
model_config.hf_config.quantization_config["quant_method"] = "gpt_oss_mxfp4"
hf_text_quant_config = getattr(
model_config.hf_text_config, "quantization_config", None
)
if (
hf_text_quant_config is not None
and hf_text_quant_config.get("quant_method") == "mxfp4"
):
model_config.hf_text_config.quantization_config["quant_method"] = (
"gpt_oss_mxfp4"
)
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
structured_outputs_config = vllm_config.structured_outputs_config
......
......@@ -560,6 +560,14 @@ class GptOssModel(nn.Module, EagleModelMixin):
pcp_rank=get_pcp_group().rank_in_group,
)
def _is_mxfp4(weight_dtype: str | None) -> bool:
"""Return True for any MXFP4 weight-dtype variant.
Covers "gpt_oss_mxfp4" (GptOssMxfp4MoEMethod) and "mxfp4"
(QuarkMoEMethod with fp4 weights) and any future variants.
"""
return weight_dtype is not None and "mxfp4" in weight_dtype
def _get_moe_weight_dtype(layer_id: int = 0) -> str | None:
"""Helper function to get MoE quantization weight dtype.
......@@ -578,7 +586,7 @@ class GptOssModel(nn.Module, EagleModelMixin):
moe_weight_dtype = _get_moe_weight_dtype(layer_id=0)
if moe_weight_dtype == "mxfp4":
if _is_mxfp4(moe_weight_dtype):
# MXFP4 requires OCP_MX_BLOCK_SIZE alignment
intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
......@@ -682,7 +690,7 @@ class GptOssModel(nn.Module, EagleModelMixin):
continue
# Unified handler for mxfp4 weights and scales
elif moe_quant_method == "mxfp4" and any(
elif _is_mxfp4(moe_quant_method) and any(
name.endswith(suffix)
for suffix in [
".w13_weight_scale",
......@@ -1116,8 +1124,22 @@ class GptOssModel(nn.Module, EagleModelMixin):
if hasattr(self.config, "quantization_config")
else None
)
# Normalize the checkpoint's quant_method to the internal name.
# Note: there are three places where "mxfp4" -> "gpt_oss_mxfp4"
# normalization occurs, each serving a different data path:
# 1. GptOssMxfp4Config.override_quantization_method() — sets
# ModelConfig.quantization (used to select the QuantizationConfig
# class at model init time), reading from model_arch_config which
# is a snapshot taken before verify_and_update_model_config runs.
# 2. GptOssForCausalLMConfig.verify_and_update_model_config() —
# patches hf_config.quantization_config in-place (a separate copy
# of the dict from model_arch_config) for later hf_config lookups.
# 3. Here — reads directly from self.config (the raw HF config) which
# may still carry the original "mxfp4" string from the checkpoint.
if quant_method == "mxfp4":
quant_method = "gpt_oss_mxfp4"
if quant_method == "gpt_oss_mxfp4":
return self._load_weights_mxfp4(
ep_rank_end,
ep_rank_start,
......
......@@ -150,7 +150,7 @@ class Base(
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
if quant_method_name == "mxfp4":
if quant_method_name in ("mxfp4", "gpt_oss_mxfp4"):
raise NotImplementedError(
"Transformers modeling backend does "
"not support MXFP4 quantization yet."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment