import json import logging import os from dataclasses import dataclass, field, fields from os.path import isdir, join from typing import Optional import huggingface_hub from transformers.utils.hub import PushToHubMixin, cached_file logger = logging.getLogger(__name__) handler = logging.StreamHandler() formatter = logging.Formatter("%(levelname)s - %(message)s") handler.setFormatter(formatter) logger.propagate = False logger.addHandler(handler) logger.setLevel(logging.INFO) CHECKPOINT_FORMAT_FIELD = "checkpoint_format" CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN = "is_marlin_format" QUANT_METHOD_FIELD = "quant_method" QUANT_CONFIG_FILENAME = "quantize_config.json" # checkpoint formats class CHECKPOINT_FORMAT: GPTQ = "gptq" MARLIN = "marlin" AWQ_GEMM = "gemm" # quant methods class QUANT_METHOD: GPTQ = "gptq" AWQ = "awq" QUANT_METHOD_FORMAT_MAPPING = { QUANT_METHOD.GPTQ: { CHECKPOINT_FORMAT.GPTQ, CHECKPOINT_FORMAT.MARLIN, }, QUANT_METHOD.AWQ: { CHECKPOINT_FORMAT.AWQ_GEMM } } # awq is inference only QUANTIZE_BLACK_LIST = {QUANT_METHOD.AWQ} # compat QUANT_CONFIG_ARG_SYNONYMS = { "w_bit": "bits", "q_group_size": "group_size", } @dataclass class BaseQuantizeConfig(PushToHubMixin): bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]}) group_size: int = field(default=-1) damp_percent: float = field(default=0.01) desc_act: bool = field(default=True) static_groups: bool = field(default=False) sym: bool = field(default=True) true_sequential: bool = field(default=True) quant_method: str = field(default=QUANT_METHOD.GPTQ) checkpoint_format: str = field(default=CHECKPOINT_FORMAT.GPTQ) model_name_or_path: Optional[str] = field(default=None) model_file_base_name: Optional[str] = field(default=None) def __post_init__(self): fields_info = fields(self) # validate quant method and format is matched valid_checkpoint_formats = QUANT_METHOD_FORMAT_MAPPING.get(self.quant_method, None) if valid_checkpoint_formats is None: raise ValueError(f"Unsupported quantization method: {self.quant_method}") if self.checkpoint_format not in valid_checkpoint_formats: raise ValueError( f"The checkpoint format used is {self.checkpoint_format}, and the quantization method is {self.quant_method}. " f"This is not supported, please open an issue at https://github.com/AutoGPTQ/AutoGPTQ/issues.") if self.bits not in fields_info[0].metadata["choices"]: raise ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.") if self.group_size != -1 and self.group_size <= 0: raise ValueError("unless equal to -1, group_size must greater then 0.") if not (0 < self.damp_percent < 1): raise ValueError("damp_percent must between 0 and 1.") def save_pretrained(self, save_dir: str, **kwargs): with open(join(save_dir, QUANT_CONFIG_FILENAME), "w", encoding="utf-8") as f: json.dump(self.to_dict(), f, indent=2) @classmethod # normalize quant config for compat and also performs validation def from_quant_config(cls, quantize_cfg, checkpoint_format: str = None): valid_formats = {CHECKPOINT_FORMAT.GPTQ, CHECKPOINT_FORMAT.MARLIN, CHECKPOINT_FORMAT.AWQ_GEMM} checkpoint_format_auto_inferred = False # compat: checkpoint_format can be passed in via from_quantized() if field missing from json if checkpoint_format: if checkpoint_format not in valid_formats: raise ValueError(f"Unknown quantization checkpoint format: {checkpoint_format}.") if quantize_cfg.get(CHECKPOINT_FORMAT_FIELD): raise ValueError("Conflict: quantization checkpoint_format is passed in and also exists in model config.") # compat: warn if checkpoint_format is missing elif quantize_cfg.get(CHECKPOINT_FORMAT_FIELD) is None: checkpoint_format_auto_inferred = True field_names = [field.name for field in fields(cls)] normalized = {QUANT_METHOD_FIELD: QUANT_METHOD.GPTQ, CHECKPOINT_FORMAT_FIELD: checkpoint_format if checkpoint_format else CHECKPOINT_FORMAT.GPTQ} for key, val in quantize_cfg.items(): key = key.lower() # remap keys according to compat map if key in QUANT_CONFIG_ARG_SYNONYMS and QUANT_CONFIG_ARG_SYNONYMS[key] in field_names: key = QUANT_CONFIG_ARG_SYNONYMS[key] if key == CHECKPOINT_FORMAT_FIELD: val = val.lower() if val in {CHECKPOINT_FORMAT.GPTQ, CHECKPOINT_FORMAT.MARLIN, CHECKPOINT_FORMAT.AWQ_GEMM}: normalized[key] = val else: raise ValueError(f"Unknown quantization format: {val}.") elif key == QUANT_METHOD_FIELD: val = val.lower() # compat: some hf models use quant_method=marlin if val == CHECKPOINT_FORMAT.MARLIN: normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.MARLIN elif val not in {QUANT_METHOD.GPTQ, QUANT_METHOD.AWQ}: raise ValueError(f"Unknown quantization method: {val}.") else: normalized[QUANT_METHOD_FIELD] = val elif key == CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN and val: normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.MARLIN elif key == "version" and val.lower() == CHECKPOINT_FORMAT.AWQ_GEMM: normalized[QUANT_METHOD_FIELD] = QUANT_METHOD.AWQ normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.AWQ_GEMM elif key in field_names: normalized[key] = val else: logger.info(f"Ignoring unknown parameter in the quantization configuration: {key}.") if checkpoint_format_auto_inferred: logger.info(f"`checkpoint_format` is missing from the quantization configuration and is automatically inferred to {normalized[CHECKPOINT_FORMAT_FIELD]}.") if normalized[CHECKPOINT_FORMAT_FIELD] in {CHECKPOINT_FORMAT.AWQ_GEMM, CHECKPOINT_FORMAT.MARLIN}: # AWQ and Marlin do not reorder the rows. normalized["desc_act"] = False if "sym" not in normalized: logger.warning( "The quantization configuration does not contain an entry `sym` (symmetric quantization). " "This may result in silent errors. Defaulting to `sym=True`." ) return cls(**normalized) @classmethod def from_pretrained(cls, save_dir: str, **kwargs): # Parameters related to loading from Hugging Face Hub cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) commit_hash = kwargs.pop("_commit_hash", None) checkpoint_format = kwargs.pop("checkpoint_format", None) transformers_config = False for quantize_config_filename in [ QUANT_CONFIG_FILENAME, "quant_config.json", "config.json", ]: if isdir(save_dir): # Local resolved_config_file = join(save_dir, quantize_config_filename) else: # Remote resolved_config_file = cached_file( save_dir, quantize_config_filename, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, use_auth_token=use_auth_token, revision=revision, local_files_only=local_files_only, subfolder=subfolder, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, _commit_hash=commit_hash, ) if resolved_config_file is not None: if quantize_config_filename == "config.json": transformers_config = True break if resolved_config_file is None: raise ValueError( "No quantize_config.json, quant_config.json or config.json file was found in the model repository." ) with open(resolved_config_file, "r", encoding="utf-8") as f: args_from_json = json.load(f) if transformers_config: args_from_json = args_from_json["quantization_config"] return cls.from_quant_config(args_from_json, checkpoint_format) def get_cache_file_path(self, quant_method: QUANT_METHOD = None, checkpoint_format: CHECKPOINT_FORMAT = None): """ Gets The Cached Weight Path. If remote: $HF_HOME/assets/autogptq/{model_name_or_path}/_{quant-method}_{checkpoint_format}.safetensors If local: {model_name_or_path}/autogptq_model_{quant-method}_{checkpoint_format}.safetensors """ use_quant_method = quant_method if quant_method else self.quant_method use_checkpoint_format = checkpoint_format if checkpoint_format else self.checkpoint_format cache_file_name = f"autogptq_model_{use_quant_method}_{use_checkpoint_format}.safetensors" if os.path.isdir(self.model_name_or_path): cache_file_name = os.path.join(self.model_name_or_path, cache_file_name) else: namespace, subfolder = self.model_name_or_path.split("/") assets_path = huggingface_hub.cached_assets_path( library_name="auto_gptq", namespace=namespace, subfolder=subfolder ) cache_file_name = os.path.join(assets_path, cache_file_name) return cache_file_name, os.path.isfile(cache_file_name) def to_dict(self): return { "bits": self.bits, "group_size": self.group_size, "damp_percent": self.damp_percent, "desc_act": self.desc_act, "static_groups": self.static_groups, "sym": self.sym, "true_sequential": self.true_sequential, "model_name_or_path": self.model_name_or_path, "model_file_base_name": self.model_file_base_name, QUANT_METHOD_FIELD: self.quant_method, CHECKPOINT_FORMAT_FIELD: self.checkpoint_format, }