Unverified commit 2256d62d, authored by Zhiyu, committed by GitHub

Modelopt quant config adaptation (#8829)

parent 6cdcbcc6
@@ -111,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
-        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
-        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
-            "kv_cache_quant_algo"
-        )
-        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
-            "exclude_modules"
-        )
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
+        # In the future, modelopt will deprecate hf_quant_config.json and keep only
+        # config.json. For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_method = None
+        exclude_modules = None
+
+        # Try the flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # For the kv cache, check if kv_cache_scheme exists and extract the algo
+            kv_cache_scheme = config.get("kv_cache_scheme")
+            if (
+                kv_cache_scheme
+                and kv_cache_scheme.get("type") == "float"
+                and kv_cache_scheme.get("num_bits") == 8
+            ):
+                kv_cache_quant_method = "FP8"
+            # Map the 'ignore' field to 'exclude_modules'
+            exclude_modules = config.get("ignore")
+        else:
+            # Fall back to the nested format (hf_quant_config.json - legacy format)
+            try:
+                quantization_section = cls.get_from_keys(config, ["quantization"])
+                quant_method = quantization_section.get("quant_algo")
+                kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
+                exclude_modules = quantization_section.get("exclude_modules")
+            except ValueError:
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+
+        if quant_method is None:
+            raise ValueError(
+                "Cannot find 'quant_algo' in the model's quantization config."
+            )
         if "FP8" not in quant_method:
             raise ValueError(
-                "ModelOpt only supports static FP8 quantization in SGLang. "
-                "Check the `hf_quant_config.json` file for your model's configuration."
+                "ModelOptFp8Config only supports static FP8 quantization in SGLang. "
+                "For FP4 quantization, use ModelOptFp4Config. "
+                "Check the quantization config for your model's configuration."
             )
         return cls(
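
For reference, a minimal sketch of the two config layouts the FP8 parser above accepts; the field values here are illustrative placeholders, not taken from a real checkpoint:

    # Legacy nested layout (hf_quant_config.json)
    hf_quant_config = {
        "quantization": {
            "quant_algo": "FP8",
            "kv_cache_quant_algo": "FP8",
            "exclude_modules": ["lm_head"],
        }
    }

    # Flat layout (the quantization_config block of config.json).
    # from_config maps kv_cache_scheme {"type": "float", "num_bits": 8}
    # to kv_cache_quant_method = "FP8", and the 'ignore' list to exclude_modules.
    quantization_config = {
        "quant_algo": "FP8",
        "kv_cache_scheme": {"type": "float", "num_bits": 8},
        "ignore": ["lm_head"],
    }
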
@@ -485,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
-        quant_config = cls.get_from_keys(config, ["quantization"])
-        quant_method = quant_config["quant_algo"]
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...}
+        # In the future, modelopt will deprecate hf_quant_config.json and keep only
+        # config.json. For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_algo = None
+        group_size = None
+        exclude_modules = []
+
+        # Try the flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # Note: FP4 models in the config.json format may not have all the detailed
+            # fields present in hf_quant_config.json, so we need to handle defaults
+            kv_cache_quant_algo = config.get("kv_cache_quant_algo")
+            if not kv_cache_quant_algo:
+                # For the config.json format, derive from kv_cache_scheme if available
+                kv_cache_scheme = config.get("kv_cache_scheme")
+                if (
+                    kv_cache_scheme
+                    and kv_cache_scheme.get("type") == "float"
+                    and kv_cache_scheme.get("num_bits") == 8
+                ):
+                    kv_cache_quant_algo = "FP8"
+                else:
+                    kv_cache_quant_algo = "auto"
+            group_size = config.get("group_size")
+            # Map the 'ignore' field to 'exclude_modules'
+            exclude_modules = config.get("ignore", [])
+        else:
+            # Fall back to the nested format (hf_quant_config.json - legacy format)
+            try:
+                quant_config = cls.get_from_keys(config, ["quantization"])
+                quant_method = quant_config["quant_algo"]
+                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
+                if not kv_cache_quant_algo:
+                    kv_cache_quant_algo = "auto"
+                group_size = quant_config.get("group_size")
+                exclude_modules = quant_config.get("exclude_modules", [])
+            except (ValueError, KeyError):
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+
         if not quant_method in ["FP8", "NVFP4"]:
             raise ValueError(
                 f"ModelOpt currently only supports: FP8, NVFP4"
                 " quantizations in sglang. Please check the "
-                "`hf_quant_config.json` file for your model's "
-                "quant configuration."
+                "quantization config for your model's configuration."
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
-        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-        if not kv_cache_quant_algo:
-            kv_cache_quant_algo = "auto"
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
-        if not (group_size and kv_cache_quant_algo and exclude_modules):
+        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
             logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"
@@ -508,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig):
             )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in "
-                "hf_quant_config.json"
+                "kv_cache_quant_algo specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
...
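
Both from_config methods share the same detection order: try the flat layout first, then fall back to the nested one. Below is a minimal standalone sketch of that shared logic; detect_quant_algo is a hypothetical helper written for illustration, not part of the SGLang API:

    from typing import Any, Dict, Optional, Tuple

    def detect_quant_algo(config: Dict[str, Any]) -> Tuple[str, Optional[str]]:
        # Flat layout: quant_algo sits at the top level (config.json quantization_config)
        quant_algo = config.get("quant_algo")
        if quant_algo is not None:
            scheme = config.get("kv_cache_scheme")
            kv_algo = (
                "FP8"
                if scheme
                and scheme.get("type") == "float"
                and scheme.get("num_bits") == 8
                else None
            )
            return quant_algo, kv_algo
        # Nested legacy layout: everything under "quantization" (hf_quant_config.json)
        nested = config.get("quantization") or {}
        if nested.get("quant_algo"):
            return nested["quant_algo"], nested.get("kv_cache_quant_algo")
        raise ValueError("Cannot find 'quant_algo' in the model's quantization config.")

    # Both layouts resolve to the same result:
    assert detect_quant_algo(
        {"quant_algo": "FP8", "kv_cache_scheme": {"type": "float", "num_bits": 8}}
    ) == ("FP8", "FP8")
    assert detect_quant_algo(
        {"quantization": {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"}}
    ) == ("FP8", "FP8")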