Unverified commit 2256d62d authored by Zhiyu, committed by GitHub

Modelopt quant config adaptation (#8829)

parent 6cdcbcc6
@@ -111,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
-        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
-        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
-            "kv_cache_quant_algo"
-        )
-        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
-            "exclude_modules"
-        )
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
+        # In the future modelopt will deprecate hf_quant_config.json and keep only
+        # config.json. For legacy reasons, we still support hf_quant_config.json.
+
+        # Initialize variables
+        kv_cache_quant_method = None
+        exclude_modules = None
+
+        # Try the flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # For the KV cache, check whether kv_cache_scheme exists and extract the algo
+            kv_cache_scheme = config.get("kv_cache_scheme")
+            if (
+                kv_cache_scheme
+                and kv_cache_scheme.get("type") == "float"
+                and kv_cache_scheme.get("num_bits") == 8
+            ):
+                kv_cache_quant_method = "FP8"
+            # Map the 'ignore' field to 'exclude_modules'
+            exclude_modules = config.get("ignore")
+        else:
+            # Fall back to the nested format (hf_quant_config.json - legacy format)
+            try:
+                quantization_section = cls.get_from_keys(config, ["quantization"])
+                quant_method = quantization_section.get("quant_algo")
+                kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
+                exclude_modules = quantization_section.get("exclude_modules")
+            except ValueError:
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either the flat format (config.json) or the nested "
+                    "format (hf_quant_config.json)."
+                )
+
+        if quant_method is None:
+            raise ValueError(
+                "Cannot find 'quant_algo' in the model's quantization config."
+            )
+
         if "FP8" not in quant_method:
             raise ValueError(
-                "ModelOpt only supports static FP8 quantization in SGLang. "
-                "Check the `hf_quant_config.json` file for your model's configuration."
+                "ModelOptFp8Config only supports static FP8 quantization in SGLang. "
+                "For FP4 quantization, use ModelOptFp4Config. "
+                "Check the quantization config for your model's configuration."
             )
         return cls(
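
For context, here is a minimal sketch of the two config shapes this parser now accepts. The field values and the "lm_head" entry are illustrative assumptions, not taken from a real checkpoint:

from typing import Any, Dict

# Flat format: the "quantization_config" section of config.json (preferred).
flat_config: Dict[str, Any] = {
    "quant_algo": "FP8",
    # Detected as an FP8 KV cache by the kv_cache_scheme check above.
    "kv_cache_scheme": {"type": "float", "num_bits": 8},
    # Mapped to exclude_modules.
    "ignore": ["lm_head"],
}

# Nested legacy format: hf_quant_config.json.
nested_config: Dict[str, Any] = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}

# Both dicts should resolve to an equivalent ModelOptFp8Config:
#   ModelOptFp8Config.from_config(flat_config)
#   ModelOptFp8Config.from_config(nested_config)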
@@ -485,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
-        quant_config = cls.get_from_keys(config, ["quantization"])
-        quant_method = quant_config["quant_algo"]
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...}
+        # In the future modelopt will deprecate hf_quant_config.json and keep only
+        # config.json. For legacy reasons, we still support hf_quant_config.json.
+
+        # Initialize variables
+        kv_cache_quant_algo = None
+        group_size = None
+        exclude_modules = []
+
+        # Try the flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # Note: FP4 models in the config.json format may not have all the detailed
+            # fields present in hf_quant_config.json, so we need to handle defaults.
+            kv_cache_quant_algo = config.get("kv_cache_quant_algo")
+            if not kv_cache_quant_algo:
+                # For the config.json format, derive from kv_cache_scheme if available
+                kv_cache_scheme = config.get("kv_cache_scheme")
+                if (
+                    kv_cache_scheme
+                    and kv_cache_scheme.get("type") == "float"
+                    and kv_cache_scheme.get("num_bits") == 8
+                ):
+                    kv_cache_quant_algo = "FP8"
+                else:
+                    kv_cache_quant_algo = "auto"
+            group_size = config.get("group_size")
+            exclude_modules = config.get("ignore", [])
+        else:
+            # Fall back to the nested format (hf_quant_config.json - legacy format)
+            try:
+                quant_config = cls.get_from_keys(config, ["quantization"])
+                quant_method = quant_config["quant_algo"]
+                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
+                if not kv_cache_quant_algo:
+                    kv_cache_quant_algo = "auto"
+                group_size = quant_config.get("group_size")
+                exclude_modules = quant_config.get("exclude_modules", [])
+            except (ValueError, KeyError):
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either the flat format (config.json) or the nested "
+                    "format (hf_quant_config.json)."
+                )
+
         if quant_method not in ["FP8", "NVFP4"]:
             raise ValueError(
                 "ModelOpt currently only supports FP8 and NVFP4 "
                 "quantization in SGLang. Please check the "
-                "`hf_quant_config.json` file for your model's "
-                "quant configuration."
+                "quantization config for your model's configuration."
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
-        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-        if not kv_cache_quant_algo:
-            kv_cache_quant_algo = "auto"
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
-        if not (group_size and kv_cache_quant_algo and exclude_modules):
+        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
             logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"
@@ -508,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig):
             )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in "
-                "hf_quant_config.json"
+                "kv_cache_quant_algo specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
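
Similarly, a hedged sketch of the inputs the FP4 path expects; the group_size of 16 and the kv_cache_scheme block are assumptions for illustration:

from typing import Any, Dict

# Flat config.json format for an NVFP4 checkpoint (values are assumptions).
flat_fp4_config: Dict[str, Any] = {
    "quant_algo": "NVFP4",
    "group_size": 16,  # typical NVFP4 block size; assumed here
    # No kv_cache_quant_algo given: the code above derives "FP8" from this
    # scheme, or falls back to "auto" when the scheme is absent.
    "kv_cache_scheme": {"type": "float", "num_bits": 8},
    "ignore": ["lm_head"],
}

# Equivalent nested hf_quant_config.json (legacy) format:
nested_fp4_config: Dict[str, Any] = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    }
}

# A missing group_size would trip the validation above and raise:
# "NVFP4 quantization requires group size and kv_cache_quant_algo
#  specified in the quantization config"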