"vscode:/vscode.git/clone" did not exist on "094b7d9496ccbdcb15bafbdab0083e54734da2d6"
Unverified Commit d0105b84 authored by sychen52's avatar sychen52 Committed by GitHub
Browse files

add mixed precision support for modelopt (#35047)


Signed-off-by: default avatarShiyang Chen <shiychen@nvidia.com>
parent 832a780f
...@@ -883,6 +883,7 @@ class ModelConfig: ...@@ -883,6 +883,7 @@ class ModelConfig:
"modelopt", "modelopt",
"modelopt_fp4", "modelopt_fp4",
"modelopt_mxfp8", "modelopt_mxfp8",
"modelopt_mixed",
"petit_nvfp4", "petit_nvfp4",
# Ensure heavy backends are probed last to avoid unnecessary # Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton) # imports during override detection (e.g., MXFP4 imports Triton)
......
...@@ -18,6 +18,7 @@ QuantizationMethods = Literal[ ...@@ -18,6 +18,7 @@ QuantizationMethods = Literal[
"modelopt", "modelopt",
"modelopt_fp4", "modelopt_fp4",
"modelopt_mxfp8", "modelopt_mxfp8",
"modelopt_mixed",
"gguf", "gguf",
"gptq_marlin", "gptq_marlin",
"awq_marlin", "awq_marlin",
...@@ -120,7 +121,12 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -120,7 +121,12 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq import GPTQConfig from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig from .gptq_marlin import GPTQMarlinConfig
from .inc import INCConfig from .inc import INCConfig
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config from .modelopt import (
ModelOptFp8Config,
ModelOptMixedPrecisionConfig,
ModelOptMxFp8Config,
ModelOptNvFp4Config,
)
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config from .petit import PetitNvFp4Config
...@@ -135,6 +141,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -135,6 +141,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"modelopt": ModelOptFp8Config, "modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config, "modelopt_fp4": ModelOptNvFp4Config,
"modelopt_mxfp8": ModelOptMxFp8Config, "modelopt_mxfp8": ModelOptMxFp8Config,
"modelopt_mixed": ModelOptMixedPrecisionConfig,
"gguf": GGUFConfig, "gguf": GGUFConfig,
"gptq_marlin": GPTQMarlinConfig, "gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig, "awq_marlin": AWQMarlinConfig,
......
...@@ -114,6 +114,8 @@ QUANT_ALGOS = [ ...@@ -114,6 +114,8 @@ QUANT_ALGOS = [
"NVFP4", "NVFP4",
# MXFP8 # MXFP8
"MXFP8", "MXFP8",
# MIXED_PRECISION,
"MIXED_PRECISION",
] ]
KV_CACHE_QUANT_ALGOS = ["FP8"] KV_CACHE_QUANT_ALGOS = ["FP8"]
...@@ -235,6 +237,26 @@ class ModelOptQuantConfigBase(QuantizationConfig): ...@@ -235,6 +237,26 @@ class ModelOptQuantConfigBase(QuantizationConfig):
self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules) self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
@staticmethod
def _extract_modelopt_quant_algo(
hf_quant_cfg: dict[str, Any] | None,
) -> str | None:
"""Extract upper-cased quant_algo from a modelopt config.
Returns the quant_algo string (upper-cased), or None if the config
is not a modelopt config.
"""
if hf_quant_cfg is None:
return None
if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
return None
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
return str(quant_config.get("quant_algo", "")).upper()
return None
return str(hf_quant_cfg.get("quant_algo", "")).upper()
@staticmethod @staticmethod
def get_config_filenames() -> list[str]: def get_config_filenames() -> list[str]:
return ["hf_quant_config.json"] return ["hf_quant_config.json"]
...@@ -272,10 +294,20 @@ class ModelOptQuantConfigBase(QuantizationConfig): ...@@ -272,10 +294,20 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# "exclude_modules" is the key in the legacy hf_quant_config.json # "exclude_modules" is the key in the legacy hf_quant_config.json
exclude_modules = quant_config.get("exclude_modules", []) exclude_modules = quant_config.get("exclude_modules", [])
else: else:
# Compressed-tensors style format: # Compressed-tensors style format (config.json quantization_config):
# {"quant_algo": "...", "quant_method": "modelopt"} # {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo") quant_method = config.get("quant_algo")
kv_cache_quant_method = config.get("kv_cache_quant_algo")
# "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
kv_cache_scheme = config.get("kv_cache_scheme")
if isinstance(kv_cache_scheme, dict) and (
kv_cache_scheme.get("type") == "float"
and kv_cache_scheme.get("num_bits") == 8
):
kv_cache_quant_method = "FP8"
else:
kv_cache_quant_method = None
# "ignore" is the key in config.json # "ignore" is the key in config.json
exclude_modules = config.get("ignore", []) exclude_modules = config.get("ignore", [])
group_size_raw = config.get("group_size") group_size_raw = config.get("group_size")
...@@ -379,32 +411,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase): ...@@ -379,32 +411,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
"""Detect if this ModelOpt config should be used based on algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
quantization config.""" if algo is not None and algo == "FP8":
return "modelopt"
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", ""))
if quant_algo.upper() == "FP8":
return "modelopt"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
if quant_algo.upper() == "FP8":
return "modelopt"
return None return None
@classmethod @classmethod
...@@ -1031,32 +1040,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase): ...@@ -1031,32 +1040,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
"""Detect if this ModelOpt FP4 config should be used based on algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
quantization config.""" if algo is not None and ("NVFP4" in algo or "FP4" in algo):
if hf_quant_cfg is None: return "modelopt_fp4"
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = quant_config.get("quant_algo", "")
if "NVFP4" in quant_algo:
return "modelopt_fp4"
else:
# Check for compressed-tensors style config with specific
# quant_algo field
quant_algo = hf_quant_cfg.get("quant_algo", "")
if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
return "modelopt_fp4"
return None return None
@classmethod @classmethod
...@@ -1619,31 +1605,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase): ...@@ -1619,31 +1605,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None: ) -> QuantizationMethods | None:
"""Detect if this ModelOpt MXFP8 config should be used based on algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
quantization config.""" if algo is not None and "MXFP8" in algo:
if hf_quant_cfg is None: return "modelopt_mxfp8"
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
return None return None
@classmethod @classmethod
...@@ -1841,3 +1805,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): ...@@ -1841,3 +1805,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
# Register the method classes for ModelOptMxFp8Config # Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
"""Config class for ModelOpt MIXED_PRECISION.
Supports checkpoints where different layers use different quantization
algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
The per-layer algorithm is specified in the ``quantized_layers`` dict
inside ``config.json``'s ``quantization_config`` (preferred) or the
legacy ``hf_quant_config.json``.
"""
def __init__(
self,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
quantized_layers: dict[str, dict[str, Any]],
fp8_config: ModelOptFp8Config,
nvfp4_config: ModelOptNvFp4Config,
) -> None:
super().__init__(exclude_modules)
self.kv_cache_quant_method = kv_cache_quant_method
self.quantized_layers = quantized_layers
self.fp8_config = fp8_config
self.nvfp4_config = nvfp4_config
def get_name(self) -> QuantizationMethods:
return "modelopt_mixed"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "MIXED_PRECISION":
return "modelopt_mixed"
return None
@classmethod
def _from_config(
cls,
*,
quant_method: str,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
original_config: dict[str, Any],
group_size: int | None,
**kwargs: Any,
) -> "ModelOptMixedPrecisionConfig":
if "quantization" in original_config:
quantized_layers = original_config["quantization"].get(
"quantized_layers", {}
)
else:
quantized_layers = original_config.get("quantized_layers", {})
if not quantized_layers:
raise ValueError(
"MIXED_PRECISION quant_algo requires a non-empty "
"'quantized_layers' mapping in the quantization config."
)
# Determine group_size from the first NVFP4 entry if not provided.
if group_size is None:
for layer_info in quantized_layers.values():
if layer_info.get("quant_algo", "").upper() == "NVFP4":
group_size = layer_info.get("group_size", 16)
break
if group_size is None:
group_size = 16
fp8_config = ModelOptFp8Config(
quant_method="FP8",
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=[],
)
nvfp4_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=kv_cache_quant_method,
exclude_modules=[],
group_size=group_size,
)
return cls(
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
quantized_layers=quantized_layers,
fp8_config=fp8_config,
nvfp4_config=nvfp4_config,
)
def _resolve_quant_algo(self, prefix: str) -> str | None:
"""Look up the quant_algo for a vLLM-side layer prefix.
Tries three strategies in order:
1. Direct lookup in ``quantized_layers``.
2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
3. Prefix-based lookup for FusedMoE (any child key starts with
``prefix + "."``).
Returns the upper-cased quant_algo string, or *None* if the prefix
is not found.
"""
# 1. Direct lookup
if prefix in self.quantized_layers:
return self.quantized_layers[prefix]["quant_algo"].upper()
# 2. Packed / fused layer lookup
proj_name = prefix.rsplit(".", 1)[-1]
if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
algos: set[str] = set()
base = prefix.rsplit(".", 1)[0]
for shard_name in self.packed_modules_mapping[proj_name]:
shard_prefix = f"{base}.{shard_name}"
if shard_prefix in self.quantized_layers:
algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
if len(algos) == 1:
return algos.pop()
if len(algos) > 1:
raise ValueError(
f"Mixed quant_algo within fused layer {prefix}: "
f"{algos}. All shards must use the same quantization."
)
# 3. Prefix-based lookup (for FusedMoE / parent modules)
prefix_dot = prefix + "."
for key, info in self.quantized_layers.items():
if key.startswith(prefix_dot):
return info["quant_algo"].upper()
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
"""Return quantize-method based on layer."""
# KV-cache quantization
if isinstance(layer, Attention):
if self.kv_cache_quant_method:
return ModelOptFp8KVCacheMethod(self)
return None
# Excluded layers
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
return None
quant_algo = self._resolve_quant_algo(prefix)
if isinstance(layer, LinearBase):
if quant_algo == "FP8":
return ModelOptFp8LinearMethod(self.fp8_config)
if quant_algo == "NVFP4":
return ModelOptNvFp4LinearMethod(self.nvfp4_config)
# Layer not in quantized_layers — leave unquantized
return UnquantizedLinearMethod()
if isinstance(layer, FusedMoE):
if quant_algo == "FP8":
return ModelOptFp8MoEMethod(
quant_config=self.fp8_config,
moe_config=layer.moe_config,
)
if quant_algo == "NVFP4":
return ModelOptNvFp4FusedMoE(
quant_config=self.nvfp4_config,
moe_config=layer.moe_config,
)
return None
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
super().apply_vllm_mapper(hf_to_vllm_mapper)
if self.quantized_layers:
self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)
...@@ -287,7 +287,17 @@ def get_quant_config( ...@@ -287,7 +287,17 @@ def get_quant_config(
) )
if hf_quant_config is not None: if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config) # For modelopt_mixed, config.json's quantization_config may or may
# not contain the per-layer quantized_layers map. Newer checkpoints
# embed it directly; older ones keep it only in hf_quant_config.json.
# If it is missing, fall through to the file-based loading path.
if (
model_config.quantization == "modelopt_mixed"
and "quantized_layers" not in hf_quant_config
):
pass # fall through to file-based loading below
else:
return quant_cls.from_config(hf_quant_config)
# if hf_quant_config is None, we will try to get config from # if hf_quant_config is None, we will try to get config from
# hf_overrides # hf_overrides
...@@ -365,8 +375,8 @@ def get_quant_config( ...@@ -365,8 +375,8 @@ def get_quant_config(
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_config.model config["adapter_name_or_path"] = model_config.model
elif model_config.quantization == "modelopt": elif model_config.quantization in ("modelopt", "modelopt_mixed"):
if config["producer"]["name"] == "modelopt": if config.get("producer", {}).get("name") == "modelopt":
return quant_cls.from_config(config) return quant_cls.from_config(config)
else: else:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment