Unverified commit ff604064, authored by Kaixi Hou and committed by GitHub

[NVIDIA] Change default quant method for model_opt (#11991)

parent 212f5e48
@@ -535,7 +535,7 @@ class ModelConfig:
         quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
         return quant_cfg
 
-    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]:
         """Parse ModelOpt quantization config and return the appropriate quant_method."""
         json_quant_configs = quant_config_dict["quantization"]
         quant_algo = json_quant_configs.get("quant_algo", None)
@@ -547,8 +547,7 @@ class ModelConfig:
         elif quant_algo and "FP8" in quant_algo:
             return {"quant_method": "modelopt_fp8"}
         else:
-            # Default to FP8 for backward compatibility
-            return {"quant_method": "modelopt_fp8"}
+            return None
 
     def _is_already_quantized(self) -> bool:
         """Check if the model is already quantized based on config files."""
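The effect of the hunks above is that an unrecognized ModelOpt quant_algo no longer silently falls back to FP8; _parse_modelopt_quant_config now returns None so the caller can fall through to other quantization handling. Below is a minimal, self-contained sketch of the new behavior; the FP4 branch and the sample quant_algo strings are assumptions, since part of the function sits outside the visible diff.

from typing import Optional


def parse_modelopt_quant_config(quant_config_dict: dict) -> Optional[dict]:
    # Standalone mirror of ModelConfig._parse_modelopt_quant_config after this
    # commit; the FP4 branch is inferred from context hidden in the diff.
    quant_algo = quant_config_dict["quantization"].get("quant_algo", None)
    if quant_algo and ("NVFP4" in quant_algo or "FP4" in quant_algo):
        return {"quant_method": "modelopt_fp4"}
    elif quant_algo and "FP8" in quant_algo:
        return {"quant_method": "modelopt_fp8"}
    else:
        # Old behavior: default to {"quant_method": "modelopt_fp8"}.
        # New behavior: report that no ModelOpt algorithm was recognized.
        return None


print(parse_modelopt_quant_config({"quantization": {"quant_algo": "FP8"}}))
# -> {'quant_method': 'modelopt_fp8'}
print(parse_modelopt_quant_config({"quantization": {"quant_algo": "NVFP4"}}))
# -> {'quant_method': 'modelopt_fp4'}
print(parse_modelopt_quant_config({"quantization": {"quant_algo": "SOME_OTHER_ALGO"}}))
# -> None (previously {'quant_method': 'modelopt_fp8'})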
@@ -179,6 +179,13 @@ class QuantizationConfig(ABC):
         elif "NVFP4" in quant_algo or "FP4" in quant_algo:
             return "modelopt_fp4"
 
+        # The hf_quant_config may be a parsed quant config, so we need to check the
+        # quant_method.
+        if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
+            return "modelopt_fp8"
+        elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
+            return "modelopt_fp4"
+
         return None
 
     @staticmethod
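The hunk just above, in QuantizationConfig, complements the first change: detection previously looked only at the raw quant_algo string, so a config that had already been parsed into {"quant_method": "modelopt_fp8"} (as ModelConfig now produces) would not be recognized. The following is a rough sketch of the lookup order after the change; the function name and the quant_algo handling before the added lines are illustrative assumptions, not copied from the source.

from typing import Optional


def resolve_modelopt_quant_method(hf_quant_config: dict) -> Optional[str]:
    # Hypothetical standalone version of the detection logic in
    # QuantizationConfig; the quant_algo branch is paraphrased from the
    # visible context of the diff.
    quant_algo = hf_quant_config.get("quant_algo") or ""
    if "FP8" in quant_algo:
        return "modelopt_fp8"
    elif "NVFP4" in quant_algo or "FP4" in quant_algo:
        return "modelopt_fp4"

    # Added by this commit: the config may already have been parsed (e.g. by
    # ModelConfig._parse_modelopt_quant_config) and carry quant_method rather
    # than quant_algo, so check that key before giving up.
    if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
        return "modelopt_fp8"
    elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
        return "modelopt_fp4"
    return None


# A raw ModelOpt config and an already-parsed config now resolve identically.
print(resolve_modelopt_quant_method({"quant_algo": "FP8"}))             # modelopt_fp8
print(resolve_modelopt_quant_method({"quant_method": "modelopt_fp8"}))  # modelopt_fp8
print(resolve_modelopt_quant_method({"quant_method": "modelopt_fp4"}))  # modelopt_fp4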