Unverified Commit eddaa2b5 authored by Qubitium, committed by GitHub

Add support for new autogptq quant_config.checkpoint_format (#332)

parent 2af565b3
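
Context for the change: per the compat comments in the diff below, AutoGPTQ <=0.7.1 marks a Marlin-serialized checkpoint with a boolean `is_marlin_format` field in `quantization_config`, while AutoGPTQ >=0.8.0 replaces it with a string `checkpoint_format` field. A minimal sketch of the two config shapes, written as Python dicts (the `bits` and `group_size` values are illustrative, not taken from any specific checkpoint):

# quantization_config as written by autogptq <= 0.7.1: bool flag
old_style_cfg = {
    "quant_method": "gptq",
    "bits": 4,          # illustrative value
    "group_size": 128,  # illustrative value
    "is_marlin_format": True,
}

# quantization_config as written by autogptq >= 0.8.0: str field
new_style_cfg = {
    "quant_method": "gptq",
    "bits": 4,          # illustrative value
    "group_size": 128,  # illustrative value
    "checkpoint_format": "marlin",
}

Both shapes must route a GPTQ checkpoint to the Marlin path, which is what the new detection logic in the diff handles.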
@@ -19,7 +19,11 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}
 
 logger = logging.getLogger("model_runner")
@@ -300,30 +304,31 @@ class ModelRunner:
         # Load weights
         linear_method = None
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+            linear_method = quant_config.get_linear_method()
         with _set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (
-                        hf_quant_method == "gptq"
-                        and "is_marlin_format" in hf_quant_config
-                        and hf_quant_config["is_marlin_format"]
-                    ):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
                     config=self.model_config.hf_config, linear_method=linear_method
                 )
...
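
Taken out of the diff, the added compat check reduces to the standalone sketch below; `resolve_quant_method` is a hypothetical helper name used here for illustration, not part of vLLM's API:

def resolve_quant_method(quant_cfg: dict) -> str:
    """Map a HF quantization_config dict to a quant method name."""
    quant_method = quant_cfg.get("quant_method", "").lower()
    # autogptq >= 0.8.0 writes checkpoint_format: str;
    # autogptq <= 0.7.1 writes is_marlin_format: bool.
    is_format_marlin = (
        quant_cfg.get("checkpoint_format") == "marlin"
        or quant_cfg.get("is_marlin_format", False)
    )
    # Route GPTQ checkpoints serialized in marlin format to the marlin config.
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method

# Both the old and the new config style resolve to "marlin":
assert resolve_quant_method({"quant_method": "gptq", "is_marlin_format": True}) == "marlin"
assert resolve_quant_method({"quant_method": "gptq", "checkpoint_format": "marlin"}) == "marlin"
assert resolve_quant_method({"quant_method": "awq"}) == "awq"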