Unverified Commit 4c764007 authored by Enrique Shockwave's avatar Enrique Shockwave Committed by GitHub
Browse files

check marlin format before attempting conversion (#4675)

parent 9f3bd2ad
......@@ -37,6 +37,14 @@ except ImportError:
logger = logging.getLogger(__name__)
def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
# compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
"is_marlin_format", False
)
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
......@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
is_marlin_format = check_marlin_format(hf_quant_cfg)
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
)
if can_convert and is_valid_user_quant:
if not is_marlin_format and can_convert and is_valid_user_quant:
msg = (
"The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name())
......@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
if not is_marlin_format and can_convert and user_quant == "gptq":
logger.info(
"Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
......@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = hf_quant_cfg.get(
"checkpoint_format"
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
is_marlin_format = check_marlin_format(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment