Unverified Commit 4c764007 authored by Enrique Shockwave's avatar Enrique Shockwave Committed by GitHub
Browse files

check marlin format before attempting conversion (#4675)

parent 9f3bd2ad
...@@ -37,6 +37,14 @@ except ImportError: ...@@ -37,6 +37,14 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    """Return True if the HF quantization config describes a Marlin-format checkpoint.

    Two config conventions are recognized:
    - gptqmodel and autogptq (eol) main use ``checkpoint_format: str``
    - autogptq <=0.7.1 uses ``is_marlin_format: bool``

    Args:
        hf_quant_cfg: The ``quantization_config`` dict from the HF model config.

    Returns:
        True when either convention marks the checkpoint as Marlin format.
    """
    # bool() guards against non-bool truthy values stored under
    # "is_marlin_format" (e.g. 1 or "true") leaking through the annotation.
    return bool(
        hf_quant_cfg.get("checkpoint_format") == "marlin"
        or hf_quant_cfg.get("is_marlin_format", False)
    )
class GPTQConfig(QuantizationConfig): class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ. """Config class for GPTQ.
...@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
is_marlin_format = check_marlin_format(hf_quant_cfg)
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = ( is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin" user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
) )
if can_convert and is_valid_user_quant: if not is_marlin_format and can_convert and is_valid_user_quant:
msg = ( msg = (
"The model is convertible to {} during runtime." "The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()) " Using {} kernel.".format(cls.get_name(), cls.get_name())
...@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
logger.info(msg) logger.info(msg)
return cls.get_name() return cls.get_name()
if can_convert and user_quant == "gptq": if not is_marlin_format and can_convert and user_quant == "gptq":
logger.info( logger.info(
"Detected that the model can run with gptq_marlin" "Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly," ", however you specified quantization=gptq explicitly,"
...@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig): ...@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str is_marlin_format = check_marlin_format(hf_quant_cfg)
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = hf_quant_cfg.get(
"checkpoint_format"
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
is_valid_user_quant = ( is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin" user_quant is None or user_quant == "gptq" or user_quant == "marlin"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment