Unverified commit 739e5945, authored by Yongye Zhu, committed by GitHub
Browse files

[Quantization] [Refactor] Create special "GptOssMxfp4MoeMethod" (#39604)


Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
parent 4d042ed8
...@@ -59,7 +59,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes. ...@@ -59,7 +59,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod] - [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`GptOssMxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.GptOssMxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
## Fused Experts Kernels ## Fused Experts Kernels
......
...@@ -951,6 +951,7 @@ class ModelConfig: ...@@ -951,6 +951,7 @@ class ModelConfig:
# Ensure heavy backends are probed last to avoid unnecessary # Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton) # imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4", "mxfp4",
"gpt_oss_mxfp4",
"cpu_awq", "cpu_awq",
"gguf", "gguf",
] ]
...@@ -966,7 +967,7 @@ class ModelConfig: ...@@ -966,7 +967,7 @@ class ModelConfig:
for name in quantization_methods: for name in quantization_methods:
method = me_quant.get_quantization_config(name) method = me_quant.get_quantization_config(name)
quantization_override = method.override_quantization_method( quantization_override = method.override_quantization_method(
quant_cfg, self.quantization quant_cfg, self.quantization, hf_config=self.hf_config
) )
if quantization_override is not None: if quantization_override is not None:
# Raise error if the override is not custom (custom would # Raise error if the override is not custom (custom would
......
...@@ -1063,7 +1063,7 @@ class FusedMoE(CustomOp): ...@@ -1063,7 +1063,7 @@ class FusedMoE(CustomOp):
expert_id: int, expert_id: int,
return_success: bool = False, return_success: bool = False,
) -> bool | None: ) -> bool | None:
if self.quant_config and self.quant_config.get_name() == "mxfp4": if self.quant_config and self.quant_config.get_name() == "gpt_oss_mxfp4":
# (FIXME) for gpt-oss all experts are combined # (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name: if "bias" in weight_name:
dim1 = loaded_weight.shape[1] dim1 = loaded_weight.shape[1]
......
...@@ -194,7 +194,7 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: ...@@ -194,7 +194,7 @@ def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
return None return None
def select_mxfp4_moe_backend( def select_gpt_oss_mxfp4_moe_backend(
config: FusedMoEConfig, config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]: ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
""" """
...@@ -400,7 +400,7 @@ def mxfp4_round_up_hidden_size_and_intermediate_size( ...@@ -400,7 +400,7 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
return hidden_size, intermediate_size return hidden_size, intermediate_size
def convert_to_mxfp4_moe_kernel_format( def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend: Mxfp4MoeBackend, mxfp4_backend: Mxfp4MoeBackend,
layer: torch.nn.Module, layer: torch.nn.Module,
w13_weight: torch.Tensor, w13_weight: torch.Tensor,
...@@ -426,7 +426,10 @@ def convert_to_mxfp4_moe_kernel_format( ...@@ -426,7 +426,10 @@ def convert_to_mxfp4_moe_kernel_format(
sf_block_size = 32 # mxfp4 block size sf_block_size = 32 # mxfp4 block size
if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): if mxfp4_backend in (
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
):
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_mxfp4_layer_for_marlin, prepare_moe_mxfp4_layer_for_marlin,
) )
......
...@@ -30,6 +30,7 @@ QuantizationMethods = Literal[ ...@@ -30,6 +30,7 @@ QuantizationMethods = Literal[
"torchao", "torchao",
"inc", "inc",
"mxfp4", "mxfp4",
"gpt_oss_mxfp4",
"mxfp8", "mxfp8",
"cpu_awq", "cpu_awq",
"online", "online",
...@@ -133,7 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -133,7 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
ModelOptNvFp4Config, ModelOptNvFp4Config,
) )
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config from .mxfp4 import GptOssMxfp4Config, Mxfp4Config
from .mxfp8 import Mxfp8Config from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig from .online.base import OnlineQuantizationConfig
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
...@@ -160,6 +161,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -160,6 +161,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": INCConfig, "auto-round": INCConfig,
"inc": INCConfig, "inc": INCConfig,
"mxfp4": Mxfp4Config, "mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config,
"mxfp8": Mxfp8Config, "mxfp8": Mxfp8Config,
"cpu_awq": CPUAWQConfig, "cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig, "online": OnlineQuantizationConfig,
......
...@@ -232,7 +232,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -232,7 +232,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None": ) -> "QuantizationMethods | None":
# Skip override to marlin kernels, as they are not # Skip override to marlin kernels, as they are not
# batch invariant # batch invariant
......
...@@ -110,13 +110,22 @@ class QuantizationConfig(ABC): ...@@ -110,13 +110,22 @@ class QuantizationConfig(ABC):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls,
hf_quant_cfg: dict[str, Any],
user_quant: str | None,
hf_config: Any = None,
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
""" """
Detects if this quantization method can support a given checkpoint Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method -- format by overriding the user specified quantization method --
this method should only be overwritten by subclasses in exceptional this method should only be overwritten by subclasses in exceptional
circumstances circumstances.
Args:
hf_quant_cfg: The checkpoint's quantization config dict.
user_quant: The user-specified quantization method string.
hf_config: The HuggingFace model config object (e.g. for
model_type checks). May be None if not available.
""" """
return None return None
......
...@@ -104,7 +104,7 @@ class CPUAWQConfig(QuantizationConfig): ...@@ -104,7 +104,7 @@ class CPUAWQConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None": ) -> "QuantizationMethods | None":
quant_method = hf_quant_cfg.get("quant_method", "").lower() quant_method = hf_quant_cfg.get("quant_method", "").lower()
if current_platform.is_cpu() and (quant_method == "awq"): if current_platform.is_cpu() and (quant_method == "awq"):
......
...@@ -84,7 +84,7 @@ class GGUFConfig(QuantizationConfig): ...@@ -84,7 +84,7 @@ class GGUFConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg: dict[str, Any], user_quant: str | None cls, hf_quant_cfg: dict[str, Any], user_quant: str | None, hf_config=None
) -> "QuantizationMethods | None": ) -> "QuantizationMethods | None":
# When user explicitly specifies --quantization gguf, override # When user explicitly specifies --quantization gguf, override
# whatever quantization method is in the HF model config (e.g. fp8). # whatever quantization method is in the HF model config (e.g. fp8).
......
...@@ -214,7 +214,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -214,7 +214,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
......
...@@ -453,7 +453,7 @@ class INCConfig(QuantizationConfig): ...@@ -453,7 +453,7 @@ class INCConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> "QuantizationMethods | None": ) -> "QuantizationMethods | None":
"""Override the `auto-round` method to `inc`.""" """Override the `auto-round` method to `inc`."""
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round" is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
......
...@@ -406,7 +406,7 @@ class ModelOptFp8Config(ModelOptQuantConfigBase): ...@@ -406,7 +406,7 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "FP8": if algo is not None and algo == "FP8":
...@@ -1028,7 +1028,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase): ...@@ -1028,7 +1028,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and ("NVFP4" in algo or "FP4" in algo): if algo is not None and ("NVFP4" in algo or "FP4" in algo):
...@@ -1525,7 +1525,7 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase): ...@@ -1525,7 +1525,7 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and "MXFP8" in algo: if algo is not None and "MXFP8" in algo:
...@@ -2052,7 +2052,7 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase): ...@@ -2052,7 +2052,7 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "MIXED_PRECISION": if algo is not None and algo == "MIXED_PRECISION":
......
...@@ -130,7 +130,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -130,7 +130,7 @@ class MoeWNA16Config(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16": if can_convert and user_quant == "moe_wna16":
......
...@@ -19,11 +19,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -19,11 +19,11 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS, TRITON_BACKENDS,
Mxfp4MoeBackend, Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format, convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel, make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size, mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend, select_gpt_oss_mxfp4_moe_backend,
) )
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -38,6 +38,12 @@ logger = init_logger(__name__) ...@@ -38,6 +38,12 @@ logger = init_logger(__name__)
class Mxfp4Config(QuantizationConfig): class Mxfp4Config(QuantizationConfig):
"""Canonical base config for MXFP4 quantization.
Subclasses override get_name() and override_quantization_method() to
register themselves as the handler for a specific checkpoint format.
"""
def __init__(self, ignored_layers: list[str] | None = None): def __init__(self, ignored_layers: list[str] | None = None):
super().__init__() super().__init__()
self.ignored_layers = ignored_layers self.ignored_layers = ignored_layers
...@@ -62,6 +68,8 @@ class Mxfp4Config(QuantizationConfig): ...@@ -62,6 +68,8 @@ class Mxfp4Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]: def get_config_filenames(cls) -> list[str]:
return [] return []
# TODO (zyongye) This is only a temporary fallback.
# We should have `Mxfp4MoEMethod` after this migration is complete.
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None": ) -> "QuantizeMethodBase | None":
...@@ -79,7 +87,7 @@ class Mxfp4Config(QuantizationConfig): ...@@ -79,7 +87,7 @@ class Mxfp4Config(QuantizationConfig):
) )
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config) return GptOssMxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
logger.debug_once( logger.debug_once(
"MXFP4 attention layer is not implemented. " "MXFP4 attention layer is not implemented. "
...@@ -93,13 +101,46 @@ class Mxfp4Config(QuantizationConfig): ...@@ -93,13 +101,46 @@ class Mxfp4Config(QuantizationConfig):
return True return True
class Mxfp4MoEMethod(FusedMoEMethodBase): class GptOssMxfp4Config(Mxfp4Config):
"""MXFP4 config for GPT-OSS checkpoints.
Checkpoints carry ``"quant_method": "mxfp4"`` in their JSON config.
override_quantization_method() maps that to the canonical internal name
so that the rest of the loading path uses "gpt_oss_mxfp4" consistently.
"""
@classmethod
def get_name(cls) -> QuantizationMethods:
    """Return the internal quantization-method name for this config.

    Returns:
        The literal string ``"gpt_oss_mxfp4"``.
    """
    return "gpt_oss_mxfp4"
@classmethod
def override_quantization_method(
    cls, hf_quant_cfg, user_quant, hf_config=None
) -> QuantizationMethods | None:
    """Claim an MXFP4 checkpoint only when it belongs to a GPT-OSS model.

    Args:
        hf_quant_cfg: The checkpoint's quantization config; anything that is
            not a dict is rejected.
        user_quant: The user-specified quantization method. Not consulted
            by this override.
        hf_config: Model config object whose ``model_type`` attribute is
            inspected. ``None`` (or an object without ``model_type``) never
            matches.

    Returns:
        ``"gpt_oss_mxfp4"`` when ``hf_quant_cfg["quant_method"]`` is
        ``"mxfp4"`` or ``"gpt_oss_mxfp4"`` and ``hf_config.model_type`` is
        ``"gpt_oss"``; otherwise ``None``.
    """
    # Match both "mxfp4" (original checkpoint value) and "gpt_oss_mxfp4"
    # (already normalized by verify_and_update_model_config) so that
    # explicit --quantization mxfp4 from the user doesn't cause a mismatch.
    if not (
        isinstance(hf_quant_cfg, dict)
        and hf_quant_cfg.get("quant_method") in ("mxfp4", "gpt_oss_mxfp4")
    ):
        return None

    # Require explicit confirmation that this is a GPT-OSS model.
    # Do NOT fall back to returning the override when hf_config is None,
    # as that would silently claim all mxfp4 checkpoints.
    model_type = getattr(hf_config, "model_type", None)
    if model_type != "gpt_oss":
        return None

    return "gpt_oss_mxfp4"
class GptOssMxfp4MoEMethod(FusedMoEMethodBase):
"""MXFP4 MoE quantization method.""" """MXFP4 MoE quantization method."""
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.weight_dtype = "mxfp4" self.weight_dtype = "gpt_oss_mxfp4"
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)
self.max_capture_size = ( self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size get_current_vllm_config().compilation_config.max_cudagraph_capture_size
...@@ -281,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -281,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# Convert weights to kernel format # Convert weights to kernel format
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_to_mxfp4_moe_kernel_format( convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend, mxfp4_backend=self.mxfp4_backend,
layer=layer, layer=layer,
w13_weight=w13, w13_weight=w13,
......
...@@ -30,11 +30,11 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_m ...@@ -30,11 +30,11 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_m
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS, TRITON_BACKENDS,
Mxfp4MoeBackend, Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format, convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel, make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size, mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend, select_gpt_oss_mxfp4_moe_backend,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin, prepare_fp8_moe_layer_for_marlin,
...@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.w2_precision_config = None self.w2_precision_config = None
if self.ocp_mx_scheme == "w_mxfp4": if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)
elif self.ocp_mx_scheme.startswith("w_mxfp4"): elif self.ocp_mx_scheme.startswith("w_mxfp4"):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes. # TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self.mxfp4_backend = Mxfp4MoeBackend.NONE self.mxfp4_backend = Mxfp4MoeBackend.NONE
...@@ -1300,7 +1300,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -1300,7 +1300,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
# Convert weights to kernel format # Convert weights to kernel format
w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
convert_to_mxfp4_moe_kernel_format( convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
mxfp4_backend=self.mxfp4_backend, mxfp4_backend=self.mxfp4_backend,
layer=layer, layer=layer,
w13_weight=w13, w13_weight=w13,
......
...@@ -108,6 +108,23 @@ class Gemma4Config(VerifyAndUpdateConfig): ...@@ -108,6 +108,23 @@ class Gemma4Config(VerifyAndUpdateConfig):
class GptOssForCausalLMConfig(VerifyAndUpdateConfig): class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
    """Rename a checkpoint's ``"mxfp4"`` quant method to ``"gpt_oss_mxfp4"``.

    ``model_config.hf_config`` and ``model_config.hf_text_config`` are
    inspected in that order. For each one, if it carries a
    ``quantization_config`` whose ``"quant_method"`` entry equals
    ``"mxfp4"``, that entry is overwritten in place with
    ``"gpt_oss_mxfp4"``. Configs without a ``quantization_config``, or with
    any other method, are left untouched.

    Args:
        model_config: Object exposing ``hf_config`` and ``hf_text_config``.
    """
    # The same check-and-patch applies to both config objects; when they are
    # the same object the second pass sees the already-renamed value and
    # does nothing.
    for config_attr in ("hf_config", "hf_text_config"):
        hf_cfg = getattr(model_config, config_attr)
        quant_config = getattr(hf_cfg, "quantization_config", None)
        if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
            hf_cfg.quantization_config["quant_method"] = "gpt_oss_mxfp4"
@staticmethod @staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None: def verify_and_update_config(vllm_config: "VllmConfig") -> None:
structured_outputs_config = vllm_config.structured_outputs_config structured_outputs_config = vllm_config.structured_outputs_config
......
...@@ -560,6 +560,14 @@ class GptOssModel(nn.Module, EagleModelMixin): ...@@ -560,6 +560,14 @@ class GptOssModel(nn.Module, EagleModelMixin):
pcp_rank=get_pcp_group().rank_in_group, pcp_rank=get_pcp_group().rank_in_group,
) )
def _is_mxfp4(weight_dtype: str | None) -> bool:
"""Return True for any MXFP4 weight-dtype variant.
Covers "gpt_oss_mxfp4" (GptOssMxfp4MoEMethod) and "mxfp4"
(QuarkMoEMethod with fp4 weights) and any future variants.
"""
return weight_dtype is not None and "mxfp4" in weight_dtype
def _get_moe_weight_dtype(layer_id: int = 0) -> str | None: def _get_moe_weight_dtype(layer_id: int = 0) -> str | None:
"""Helper function to get MoE quantization weight dtype. """Helper function to get MoE quantization weight dtype.
...@@ -578,7 +586,7 @@ class GptOssModel(nn.Module, EagleModelMixin): ...@@ -578,7 +586,7 @@ class GptOssModel(nn.Module, EagleModelMixin):
moe_weight_dtype = _get_moe_weight_dtype(layer_id=0) moe_weight_dtype = _get_moe_weight_dtype(layer_id=0)
if moe_weight_dtype == "mxfp4": if _is_mxfp4(moe_weight_dtype):
# MXFP4 requires OCP_MX_BLOCK_SIZE alignment # MXFP4 requires OCP_MX_BLOCK_SIZE alignment
intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
...@@ -682,7 +690,7 @@ class GptOssModel(nn.Module, EagleModelMixin): ...@@ -682,7 +690,7 @@ class GptOssModel(nn.Module, EagleModelMixin):
continue continue
# Unified handler for mxfp4 weights and scales # Unified handler for mxfp4 weights and scales
elif moe_quant_method == "mxfp4" and any( elif _is_mxfp4(moe_quant_method) and any(
name.endswith(suffix) name.endswith(suffix)
for suffix in [ for suffix in [
".w13_weight_scale", ".w13_weight_scale",
...@@ -1116,8 +1124,22 @@ class GptOssModel(nn.Module, EagleModelMixin): ...@@ -1116,8 +1124,22 @@ class GptOssModel(nn.Module, EagleModelMixin):
if hasattr(self.config, "quantization_config") if hasattr(self.config, "quantization_config")
else None else None
) )
# Normalize the checkpoint's quant_method to the internal name.
# Note: there are three places where "mxfp4" -> "gpt_oss_mxfp4"
# normalization occurs, each serving a different data path:
# 1. GptOssMxfp4Config.override_quantization_method() — sets
# ModelConfig.quantization (used to select the QuantizationConfig
# class at model init time), reading from model_arch_config which
# is a snapshot taken before verify_and_update_model_config runs.
# 2. GptOssForCausalLMConfig.verify_and_update_model_config() —
# patches hf_config.quantization_config in-place (a separate copy
# of the dict from model_arch_config) for later hf_config lookups.
# 3. Here — reads directly from self.config (the raw HF config) which
# may still carry the original "mxfp4" string from the checkpoint.
if quant_method == "mxfp4": if quant_method == "mxfp4":
quant_method = "gpt_oss_mxfp4"
if quant_method == "gpt_oss_mxfp4":
return self._load_weights_mxfp4( return self._load_weights_mxfp4(
ep_rank_end, ep_rank_end,
ep_rank_start, ep_rank_start,
......
...@@ -150,7 +150,7 @@ class Base( ...@@ -150,7 +150,7 @@ class Base(
if self.quant_config: if self.quant_config:
quant_method_name = self.quant_config.get_name() quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods. # Check for unsupported quantization methods.
if quant_method_name == "mxfp4": if quant_method_name in ("mxfp4", "gpt_oss_mxfp4"):
raise NotImplementedError( raise NotImplementedError(
"Transformers modeling backend does " "Transformers modeling backend does "
"not support MXFP4 quantization yet." "not support MXFP4 quantization yet."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment