Commit da900c3b authored by yangql

Initial commit

import json
import logging
import os
from dataclasses import dataclass, field, fields
from os.path import isdir, join
from typing import Optional
import huggingface_hub
from transformers.utils.hub import PushToHubMixin, cached_file
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.propagate = False
logger.addHandler(handler)
logger.setLevel(logging.INFO)
CHECKPOINT_FORMAT_FIELD = "checkpoint_format"
CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN = "is_marlin_format"
QUANT_METHOD_FIELD = "quant_method"
QUANT_CONFIG_FILENAME = "quantize_config.json"
# checkpoint formats
class CHECKPOINT_FORMAT:
GPTQ = "gptq"
MARLIN = "marlin"
AWQ_GEMM = "gemm"
# quant methods
class QUANT_METHOD:
GPTQ = "gptq"
AWQ = "awq"
QUANT_METHOD_FORMAT_MAPPING = {
QUANT_METHOD.GPTQ: {
CHECKPOINT_FORMAT.GPTQ,
CHECKPOINT_FORMAT.MARLIN,
},
QUANT_METHOD.AWQ: {
CHECKPOINT_FORMAT.AWQ_GEMM
}
}
# awq is inference only
QUANTIZE_BLACK_LIST = {QUANT_METHOD.AWQ}
# compat
QUANT_CONFIG_ARG_SYNONYMS = {
"w_bit": "bits",
"q_group_size": "group_size",
}
@dataclass
class BaseQuantizeConfig(PushToHubMixin):
bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
group_size: int = field(default=-1)
damp_percent: float = field(default=0.01)
desc_act: bool = field(default=True)
static_groups: bool = field(default=False)
sym: bool = field(default=True)
true_sequential: bool = field(default=True)
quant_method: str = field(default=QUANT_METHOD.GPTQ)
checkpoint_format: str = field(default=CHECKPOINT_FORMAT.GPTQ)
model_name_or_path: Optional[str] = field(default=None)
model_file_base_name: Optional[str] = field(default=None)
def __post_init__(self):
fields_info = fields(self)
# validate quant method and format is matched
valid_checkpoint_formats = QUANT_METHOD_FORMAT_MAPPING.get(self.quant_method, None)
if valid_checkpoint_formats is None:
raise ValueError(f"Unsupported quantization method: {self.quant_method}")
if self.checkpoint_format not in valid_checkpoint_formats:
raise ValueError(
f"The checkpoint format used is {self.checkpoint_format}, and the quantization method is {self.quant_method}. "
f"This is not supported, please open an issue at https://github.com/AutoGPTQ/AutoGPTQ/issues.")
        if self.bits not in fields_info[0].metadata["choices"]:
            raise ValueError(f"only support quantizing to {fields_info[0].metadata['choices']} bits.")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("unless equal to -1, group_size must be greater than 0.")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")
def save_pretrained(self, save_dir: str, **kwargs):
with open(join(save_dir, QUANT_CONFIG_FILENAME), "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2)
@classmethod
# normalize quant config for compat and also performs validation
def from_quant_config(cls, quantize_cfg, checkpoint_format: str = None):
valid_formats = {CHECKPOINT_FORMAT.GPTQ, CHECKPOINT_FORMAT.MARLIN, CHECKPOINT_FORMAT.AWQ_GEMM}
checkpoint_format_auto_inferred = False
# compat: checkpoint_format can be passed in via from_quantized() if field missing from json
if checkpoint_format:
if checkpoint_format not in valid_formats:
raise ValueError(f"Unknown quantization checkpoint format: {checkpoint_format}.")
if quantize_cfg.get(CHECKPOINT_FORMAT_FIELD):
raise ValueError("Conflict: quantization checkpoint_format is passed in and also exists in model config.")
# compat: warn if checkpoint_format is missing
elif quantize_cfg.get(CHECKPOINT_FORMAT_FIELD) is None:
checkpoint_format_auto_inferred = True
field_names = [field.name for field in fields(cls)]
normalized = {QUANT_METHOD_FIELD: QUANT_METHOD.GPTQ, CHECKPOINT_FORMAT_FIELD: checkpoint_format if checkpoint_format else CHECKPOINT_FORMAT.GPTQ}
for key, val in quantize_cfg.items():
key = key.lower()
# remap keys according to compat map
if key in QUANT_CONFIG_ARG_SYNONYMS and QUANT_CONFIG_ARG_SYNONYMS[key] in field_names:
key = QUANT_CONFIG_ARG_SYNONYMS[key]
if key == CHECKPOINT_FORMAT_FIELD:
val = val.lower()
if val in {CHECKPOINT_FORMAT.GPTQ, CHECKPOINT_FORMAT.MARLIN, CHECKPOINT_FORMAT.AWQ_GEMM}:
normalized[key] = val
else:
raise ValueError(f"Unknown quantization format: {val}.")
elif key == QUANT_METHOD_FIELD:
val = val.lower()
# compat: some hf models use quant_method=marlin
if val == CHECKPOINT_FORMAT.MARLIN:
normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.MARLIN
elif val not in {QUANT_METHOD.GPTQ, QUANT_METHOD.AWQ}:
raise ValueError(f"Unknown quantization method: {val}.")
else:
normalized[QUANT_METHOD_FIELD] = val
elif key == CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN and val:
normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.MARLIN
elif key == "version" and val.lower() == CHECKPOINT_FORMAT.AWQ_GEMM:
normalized[QUANT_METHOD_FIELD] = QUANT_METHOD.AWQ
normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.AWQ_GEMM
elif key in field_names:
normalized[key] = val
else:
logger.info(f"Ignoring unknown parameter in the quantization configuration: {key}.")
if checkpoint_format_auto_inferred:
logger.info(f"`checkpoint_format` is missing from the quantization configuration and is automatically inferred to {normalized[CHECKPOINT_FORMAT_FIELD]}.")
if normalized[CHECKPOINT_FORMAT_FIELD] in {CHECKPOINT_FORMAT.AWQ_GEMM, CHECKPOINT_FORMAT.MARLIN}:
# AWQ and Marlin do not reorder the rows.
normalized["desc_act"] = False
if "sym" not in normalized:
logger.warning(
"The quantization configuration does not contain an entry `sym` (symmetric quantization). "
"This may result in silent errors. Defaulting to `sym=True`."
)
return cls(**normalized)
@classmethod
def from_pretrained(cls, save_dir: str, **kwargs):
# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
commit_hash = kwargs.pop("_commit_hash", None)
checkpoint_format = kwargs.pop("checkpoint_format", None)
transformers_config = False
for quantize_config_filename in [
QUANT_CONFIG_FILENAME,
"quant_config.json",
"config.json",
]:
if isdir(save_dir): # Local
resolved_config_file = join(save_dir, quantize_config_filename)
else: # Remote
resolved_config_file = cached_file(
save_dir,
quantize_config_filename,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
if resolved_config_file is not None:
if quantize_config_filename == "config.json":
transformers_config = True
break
if resolved_config_file is None:
raise ValueError(
"No quantize_config.json, quant_config.json or config.json file was found in the model repository."
)
with open(resolved_config_file, "r", encoding="utf-8") as f:
args_from_json = json.load(f)
if transformers_config:
args_from_json = args_from_json["quantization_config"]
return cls.from_quant_config(args_from_json, checkpoint_format)
    def get_cache_file_path(self, quant_method: Optional[str] = None, checkpoint_format: Optional[str] = None):
"""
Gets The Cached Weight Path.
If remote: $HF_HOME/assets/autogptq/{model_name_or_path}/_{quant-method}_{checkpoint_format}.safetensors
If local: {model_name_or_path}/autogptq_model_{quant-method}_{checkpoint_format}.safetensors
"""
use_quant_method = quant_method if quant_method else self.quant_method
use_checkpoint_format = checkpoint_format if checkpoint_format else self.checkpoint_format
cache_file_name = f"autogptq_model_{use_quant_method}_{use_checkpoint_format}.safetensors"
if os.path.isdir(self.model_name_or_path):
cache_file_name = os.path.join(self.model_name_or_path, cache_file_name)
else:
namespace, subfolder = self.model_name_or_path.split("/")
assets_path = huggingface_hub.cached_assets_path(
library_name="auto_gptq", namespace=namespace, subfolder=subfolder
)
cache_file_name = os.path.join(assets_path, cache_file_name)
return cache_file_name, os.path.isfile(cache_file_name)
def to_dict(self):
return {
"bits": self.bits,
"group_size": self.group_size,
"damp_percent": self.damp_percent,
"desc_act": self.desc_act,
"static_groups": self.static_groups,
"sym": self.sym,
"true_sequential": self.true_sequential,
"model_name_or_path": self.model_name_or_path,
"model_file_base_name": self.model_file_base_name,
QUANT_METHOD_FIELD: self.quant_method,
CHECKPOINT_FORMAT_FIELD: self.checkpoint_format,
}
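# Minimal usage sketch (not additional API): it shows how a raw quantize_config dict, including
# the compat synonyms `w_bit`/`q_group_size`, is normalized by BaseQuantizeConfig.from_quant_config.
# Guarded so it only runs when this file is executed directly.
if __name__ == "__main__":
    raw_cfg = {"w_bit": 4, "q_group_size": 128, "desc_act": False, "sym": True}
    cfg = BaseQuantizeConfig.from_quant_config(raw_cfg)
    # `w_bit` -> `bits`, `q_group_size` -> `group_size`; method/format default to gptq.
    print(cfg.to_dict())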
import math
import os
import time
from logging import getLogger
import torch
import torch.nn as nn
import transformers
from .quantizer import Quantizer
logger = getLogger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class GPTQ:
def __init__(self, layer):
self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
W = W.t()
self.rows = W.shape[0]
self.columns = W.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
self.quantizer = Quantizer()
def add_batch(self, inp, out):
if os.environ.get("DEBUG"):
self.inp1 = inp
self.out1 = out
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
if isinstance(self.layer, nn.Conv2d):
unfold = nn.Unfold(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.stride,
)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
# inp = inp.float()
inp = math.sqrt(2 / self.nsamples) * inp.float()
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t())
def fasterquant(
self,
blocksize=128,
percdamp=0.01,
group_size=-1,
actorder=False,
static_groups=False,
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.float()
tick = time.time()
if not self.quantizer.ready():
self.quantizer.find_params(W, weight=True)
H = self.H
del self.H
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
g_idx = []
scale = []
zero = []
now_idx = 1
if static_groups:
import copy
groups = []
for i in range(0, self.columns, group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i : (i + group_size)], weight=True)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
try:
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
        except torch._C._LinAlgError as e:
            logger.warning(f"Cholesky decomposition failed ({e}); retrying with a small diagonal offset.")
            # Add a small diagonal offset to strengthen positive definiteness.
            epsilon = 1e-3
            H += epsilon * torch.eye(H.size(0), device=H.device)
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            H = torch.linalg.cholesky(H, upper=True)
Hinv = H
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if group_size != -1:
if not static_groups:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True)
if ((i1 + i) // group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
idx = i1 + i
if actorder:
idx = perm[idx]
self.quantizer = groups[idx // group_size]
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if os.environ.get("DEBUG"):
self.layer.weight.data[:, :i2] = Q[:, :i2]
self.layer.weight.data[:, i2:] = W[:, i2:]
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
logger.debug(torch.sum(Losses))
torch.cuda.synchronize()
logger.info(f"duration: {(time.time() - tick)}")
logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")
group_size = group_size if group_size != -1 else self.columns
if static_groups and actorder:
g_idx = [perm[i] // group_size for i in range(self.columns)]
else:
g_idx = [i // group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
if os.environ.get("DEBUG"):
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
if scale == []:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
return scale, zero, g_idx
def free(self):
if os.environ.get("DEBUG"):
self.inp1 = None
self.out1 = None
self.H = None
self.Losses = None
self.Trace = None
torch.cuda.empty_cache()
__all__ = ["GPTQ"]
from logging import getLogger
import torch
import torch.nn as nn
logger = getLogger(__name__)
def quantize(x, scale, zero, maxq):
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer("maxq", torch.tensor(0))
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))
def configure(
self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=0.8,
trits=False,
):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
if trits:
self.maxq = torch.tensor(-1)
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
if self.maxq < 0:
self.scale = xmax
self.zero = xmin
else:
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
return self.maxq > 0
def ready(self):
return torch.all(self.scale != 0)
__all__ = ["Quantizer"]
from .perplexity_utils import Perplexity
import gc
import json
import logging
import os
import shutil
import tempfile
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from accelerate.utils.constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from accelerate.utils.modeling import (
check_tied_parameters_in_config,
check_tied_parameters_on_same_device,
find_tied_parameters,
load_offloaded_weights,
load_state_dict,
retie_parameters,
set_module_tensor_to_device,
)
from accelerate.utils.offload import offload_weight, save_offload_index
logger = logging.getLogger(__name__)
# TODO: Remove and use instead accelerate.utils.modeling.load_checkpoint_in_model once https://github.com/huggingface/accelerate/pull/2588 is merged & accelerate 0.29 is released.
def load_checkpoint_in_model(
model: nn.Module,
checkpoint: Union[str, os.PathLike],
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
offload_state_dict: bool = False,
offload_buffers: bool = False,
keep_in_fp32_modules: List[str] = None,
offload_8bit_bnb: bool = False,
strict: bool = False,
):
"""
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
loaded.
<Tip warning={true}>
Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To
group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`].
</Tip>
Args:
model (`torch.nn.Module`):
The model in which we want to load a checkpoint.
checkpoint (`str` or `os.PathLike`):
The folder checkpoint to load. It can be:
- a path to a file containing a whole model state dict
- a path to a `.json` file containing the index to a sharded checkpoint
- a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
- a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
name, once a given module name is inside, every submodule of it will be sent to the same device.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
dtype (`str` or `torch.dtype`, *optional*):
If provided, the weights will be converted to that type when loaded.
offload_state_dict (`bool`, *optional*, defaults to `False`):
If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
the weight of the CPU state dict + the biggest shard does not fit.
offload_buffers (`bool`, *optional*, defaults to `False`):
Whether or not to include the buffers in the weights offloaded to disk.
keep_in_fp32_modules(`List[str]`, *optional*):
A list of the modules that we keep in `torch.float32` dtype.
offload_8bit_bnb (`bool`, *optional*):
Whether or not to enable offload of 8-bit modules on cpu/disk.
strict (`bool`, *optional*, defaults to `False`):
Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's state_dict.
"""
if offload_8bit_bnb:
from accelerate.utils.bnb import quantize_and_offload_8bit
tied_params = find_tied_parameters(model)
if check_tied_parameters_in_config(model) and len(tied_params) == 0:
        logger.warning(
            "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device_map` function."
        )
if device_map is not None:
check_tied_parameters_on_same_device(tied_params, device_map)
if offload_folder is None and device_map is not None and "disk" in device_map.values():
raise ValueError(
"At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`."
)
elif offload_folder is not None and device_map is not None and "disk" in device_map.values():
os.makedirs(offload_folder, exist_ok=True)
if isinstance(dtype, str):
# We accept "torch.float16" or just "float16"
dtype = dtype.replace("torch.", "")
dtype = getattr(torch, dtype)
checkpoint_files = None
index_filename = None
if os.path.isfile(checkpoint):
if str(checkpoint).endswith(".json"):
index_filename = checkpoint
else:
checkpoint_files = [checkpoint]
elif os.path.isdir(checkpoint):
# check if the whole state dict is present
potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME]
if len(potential_state_bin) == 1:
checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
elif len(potential_state_safetensor) == 1:
checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
else:
# otherwise check for sharded checkpoints
potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
if len(potential_index) == 0:
raise ValueError(
f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
)
elif len(potential_index) == 1:
index_filename = os.path.join(checkpoint, potential_index[0])
else:
raise ValueError(
f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
)
else:
raise ValueError(
"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
)
if index_filename is not None:
checkpoint_folder = os.path.split(index_filename)[0]
with open(index_filename) as f:
index = json.loads(f.read())
if "weight_map" in index:
index = index["weight_map"]
checkpoint_files = sorted(list(set(index.values()))) # noqa: C414
checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
    # Logic for missing/unexpected keys goes here.
offload_index = {}
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
unexpected_keys = set()
model_keys = set(model.state_dict().keys())
buffer_names = [name for name, _ in model.named_buffers()]
for checkpoint_file in checkpoint_files:
loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
if device_map is None:
model.load_state_dict(loaded_checkpoint, strict=strict)
unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys)
else:
for param_name, param in loaded_checkpoint.items():
# skip SCB parameter (for 8-bit serialization)
if "SCB" in param_name:
continue
if param_name not in model_keys:
unexpected_keys.add(param_name)
if not strict:
continue # Skip loading this parameter.
module_name = param_name
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]
new_dtype = dtype
if dtype is not None and torch.is_floating_point(param):
if keep_in_fp32_modules is not None and dtype == torch.float16:
proceed = False
for key in keep_in_fp32_modules:
if ((key in param_name) and (key + "." in param_name)) or key == param_name:
proceed = True
break
if proceed:
new_dtype = torch.float32
if "weight" in param_name and param_name.replace("weight", "SCB") in loaded_checkpoint.keys():
if param.dtype == torch.int8:
fp16_statistics = loaded_checkpoint[param_name.replace("weight", "SCB")]
else:
fp16_statistics = None
if param_device == "disk":
if offload_buffers or param_name not in buffer_names:
if new_dtype is None:
new_dtype = param.dtype
if offload_8bit_bnb:
quantize_and_offload_8bit(
model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics
)
continue
else:
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
offload_weight(param, param_name, offload_folder, index=offload_index)
elif param_device == "cpu" and offload_state_dict:
if new_dtype is None:
new_dtype = param.dtype
if offload_8bit_bnb:
quantize_and_offload_8bit(
model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics
)
else:
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
else:
set_module_tensor_to_device(
model,
param_name,
param_device,
value=param,
dtype=new_dtype,
fp16_statistics=fp16_statistics,
)
# Force Python to clean up.
del loaded_checkpoint
gc.collect()
if not strict and len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {checkpoint} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint."
)
save_offload_index(offload_index, offload_folder)
# Load back offloaded state dict on CPU
if offload_state_dict:
load_offloaded_weights(model, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
retie_parameters(model, tied_params)
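# Minimal usage sketch, assuming accelerate is installed: round-trips a tiny model through a
# temporary checkpoint file and loads it back with an explicit all-CPU device_map. The file
# name below is just a placeholder.
if __name__ == "__main__":
    src = nn.Linear(8, 4)
    dst = nn.Linear(8, 4)
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "pytorch_model.bin")
        torch.save(src.state_dict(), ckpt_path)
        load_checkpoint_in_model(dst, ckpt_path, device_map={"": "cpu"})
    assert torch.equal(src.weight, dst.weight)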
import copy
import random
from functools import partial
from typing import Callable, Dict, List, Optional
import torch
from datasets import DatasetDict, IterableDatasetDict, load_dataset
from torch import LongTensor
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer
def make_data_block(
samples: Dict[str, List[str]],
prompt_col_name: str,
label_col_name: str,
tokenizer: PreTrainedTokenizer,
preprocess_fn: Optional[Callable] = None,
sample_max_len: int = 1024,
block_max_len: int = 2048,
add_eos_token: bool = False,
truncate_prompt: bool = True,
merge_prompt_label: bool = False,
) -> Dict[str, List[LongTensor]]:
"""A simple implementation of text generation oriented smart batching to maximize VRAM usage when evaluation
:param samples: Dict[str, List[str]], samples that used to make data blocks
:param prompt_col_name: str, name of the key in samples whose value stores prompt
:param label_col_name: str, name of the key in samples whose value stores label
:param tokenizer: transformers.PretrainedTokenizer, tokenizer that used to tokenize samples
:param preprocess_fn: Optional[Callable], optional function that used to preprocess samples such as
refactor the data structure of samples, note the output of this function must be a dict whose keys
at least contains `prompt_col_name` and `label_col_name`
:param sample_max_len: int, defaults to 1024, max tokens number of each sample (before padding)
:param block_max_len: int, defaults to 2048, max tokens number of each data block (after padding)
:param add_eos_token: bool, defaults to False, whether add eos_token or not to the label
:param truncate_prompt: bool, defaults to True, whether to truncate prompt if the sample's total tokens
number exceeds `sample_max_len`, if not, will truncate label and drop this sample when all tokens
in label are truncated
:param merge_prompt_label: bool, defaults to False, will merge label into prompt if set to True, usually
this only required when doing language modeling task
:return: Dict[str, List[torch.LongTensor]], a dict whose keys are `input_ids`, `attention_mask` and
`label` and values are a list of torch.LongTensor
"""
if preprocess_fn:
samples = preprocess_fn(samples)
prompts = samples[prompt_col_name]
labels = samples[label_col_name]
# tokenize samples
tokenized_prompts = tokenizer(prompts, truncation=False)["input_ids"]
tokenized_labels = tokenizer(labels, truncation=False)["input_ids"]
# filter tokenized samples by length
dropped_indices = []
for idx, (tokenized_prompt, tokenized_label) in enumerate(zip(tokenized_prompts, tokenized_labels)):
if add_eos_token:
tokenized_label += [tokenizer.eos_token_id]
len_prompt = len(tokenized_prompt)
len_label = len(tokenized_label)
exceed_len = len_prompt + len_label - sample_max_len
if exceed_len > 0:
if truncate_prompt:
tokenized_prompt = tokenized_prompt[exceed_len:]
else:
tokenized_label = tokenized_label[:-exceed_len]
tokenized_prompts[idx] = tokenized_prompt
tokenized_labels[idx] = tokenized_label
if not tokenized_label:
dropped_indices.append(idx)
# make data blocks of samples
tokenized_samples = sorted(
[(p, l) for idx, (p, l) in enumerate(zip(tokenized_prompts, tokenized_labels)) if idx not in dropped_indices],
key=lambda x: (len(x[0]) + len(x[1])) if merge_prompt_label else len(x[0]),
)
sample_blocks = []
sample_block = []
blk_max_len = 0
blk_total_len = 0
for tokenized_sample in tokenized_samples:
prompt_ids, label_ids = tokenized_sample
ori_sample_len = len(prompt_ids)
if merge_prompt_label:
ori_sample_len += len(label_ids)
if ori_sample_len <= blk_max_len:
additional_len = blk_max_len
sample_len = blk_max_len
else:
additional_len = len(sample_block) * (ori_sample_len - blk_max_len) + ori_sample_len
sample_len = ori_sample_len
if blk_total_len + additional_len > block_max_len:
sample_blocks.append((copy.copy(sample_block), blk_max_len))
sample_block = []
blk_max_len = 0
blk_total_len = 0
sample_len = ori_sample_len
additional_len = ori_sample_len
sample_block.append(tokenized_sample)
blk_max_len = max(blk_max_len, sample_len)
blk_total_len += additional_len
if sample_block:
sample_blocks.append((copy.copy(sample_block), blk_max_len))
del sample_block
del blk_max_len
del blk_total_len
new_samples = {"input_ids": [], "attention_mask": [], "labels": []}
# padding each data block internally
for block, blk_max_len in sample_blocks:
input_ids = []
attention_mask = []
label_ids = []
label_max_len = max([len(sample[1]) for sample in block])
for sample in block:
tokenized_prompt, tokenized_label = sample
sample_len = len(tokenized_prompt)
if merge_prompt_label:
sample_len += len(tokenized_label)
pad_num = blk_max_len - sample_len
if merge_prompt_label:
input_ids.append([tokenizer.pad_token_id] * pad_num + tokenized_prompt + tokenized_label)
label_ids.append([-100] * (pad_num + len(tokenized_prompt)) + tokenized_label)
else:
input_ids.append([tokenizer.pad_token_id] * pad_num + tokenized_prompt)
label_ids.append([-100] * (label_max_len - len(tokenized_label)) + tokenized_label)
attention_mask.append([0] * pad_num + [1] * sample_len)
new_samples["input_ids"].append(input_ids)
new_samples["attention_mask"].append(attention_mask)
new_samples["labels"].append(label_ids)
return new_samples
def collate_data(blocks: List[Dict[str, List[List[int]]]], pad_token_id: int) -> Dict[str, LongTensor]:
def pad_block(block, pads):
return torch.cat((pads.to(block.device), block), dim=-1)
input_ids_blocks = [LongTensor(block["input_ids"]) for block in blocks]
attention_mask_blocks = [LongTensor(block["attention_mask"]) for block in blocks]
label_blocks = [LongTensor(block["labels"]) for block in blocks]
bsz = len(blocks)
inp_max_len = max([block.size(-1) for block in input_ids_blocks])
label_max_len = max([block.size(-1) for block in label_blocks])
for i in range(bsz):
block_bsz, block_inp_len = input_ids_blocks[i].shape
block_label_len = label_blocks[i].shape[-1]
pad_num = inp_max_len - block_inp_len
if pad_num > 0:
input_ids_blocks[i] = pad_block(input_ids_blocks[i], torch.ones((block_bsz, pad_num)) * pad_token_id)
attention_mask_blocks[i] = pad_block(attention_mask_blocks[i], torch.zeros((block_bsz, pad_num)))
label_pad_num = label_max_len - block_label_len
if label_pad_num > 0:
label_blocks[i] = pad_block(label_blocks[i], torch.ones((block_bsz, label_pad_num)) * -100)
return {
"input_ids": torch.cat(input_ids_blocks, dim=0).long(),
"attention_mask": torch.cat(attention_mask_blocks, dim=0).long(),
"labels": torch.cat(label_blocks, dim=0).long(),
}
def get_dataloader(
data_path_or_name: str,
prompt_col_name: str,
label_col_name: str,
tokenizer: PreTrainedTokenizer,
load_fn: Optional[Callable] = None,
preprocess_fn: Optional[Callable] = None,
num_samples: int = 128,
sample_max_len: int = 1024,
block_max_len: int = 2048,
add_eos_token: bool = False,
truncate_prompt: bool = True,
merge_prompt_label: bool = False,
load_fn_kwargs: Optional[dict] = None,
preprocess_fn_kwargs: Optional[dict] = None,
**kwargs,
) -> DataLoader:
"""load dataset and build dataloader
:param data_path_or_name: str, dataset name in hf-hub or local file path
:param prompt_col_name: str, see `make_data_block`
:param label_col_name: str, see `make_data_block`
    :param tokenizer: transformers.PreTrainedTokenizer, see `make_data_block`
:param load_fn: Optional[Callable], defaults to None, function used to load dataset, if not specified,
use `datasets.load_dataset`
:param preprocess_fn: Optional[Callable], see `make_data_block`
    :param num_samples: int, defaults to 128, total number of samples used for evaluation
:param sample_max_len: int, see `make_data_block`
:param block_max_len: int, see `make_data_block`
:param add_eos_token: bool, see `make_data_block`
:param truncate_prompt: bool, see `make_data_block`
:param merge_prompt_label: bool, see `make_data_block`
:param load_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
for `load_fn` or `datasets.load_dataset`
:param preprocess_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
for `preprocess_fn`
    :param kwargs: additional keyword arguments passed to torch's `DataLoader` initialization;
        note the values of `batch_size`, `shuffle` and `collate_fn` are always overridden with fixed values
:return: torch.utils.data.DataLoader
"""
if not load_fn_kwargs:
load_fn_kwargs = {}
if not preprocess_fn_kwargs:
preprocess_fn_kwargs = {}
if load_fn:
ds = load_fn(data_path_or_name, **load_fn_kwargs)
else:
ds = load_dataset(data_path_or_name, **load_fn_kwargs)
if isinstance(ds, (DatasetDict, IterableDatasetDict)):
if "evaluation" in ds:
ds = ds["evaluation"]
elif "test" in ds:
ds = ds["test"]
else:
ds = ds["train"]
ds = ds.select(
indices=random.sample(range(len(ds)), min(len(ds), num_samples)),
keep_in_memory=True,
)
ds = ds.map(
make_data_block,
batched=True,
batch_size=len(ds),
num_proc=1,
remove_columns=ds.column_names,
keep_in_memory=True,
load_from_cache_file=False,
fn_kwargs={
"prompt_col_name": prompt_col_name,
"label_col_name": label_col_name,
"tokenizer": tokenizer,
"preprocess_fn": partial(preprocess_fn, **preprocess_fn_kwargs),
"sample_max_len": sample_max_len,
"block_max_len": block_max_len,
"add_eos_token": add_eos_token,
"truncate_prompt": truncate_prompt,
"merge_prompt_label": merge_prompt_label,
},
)
    # these values in kwargs are always overridden, regardless of what the user specified
kwargs["batch_size"] = 1
kwargs["shuffle"] = False
kwargs["collate_fn"] = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
dl = DataLoader(ds, **kwargs)
return dl
__all__ = ["make_data_block", "collate_data", "get_dataloader"]
import gc
import torch
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
def exllama_set_max_input_length(model, max_input_length: int):
"""
This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM.
When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. In case the
default used (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model.
"""
# The import is set here to avoid a global import. Arguably this is quite ugly, it would be better to have lazy loading.
from exllama_kernels import cleanup_buffers_cuda, prepare_buffers
if not model.quantize_config.desc_act:
raise ValueError(
"The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**."
)
uses_exllama = False
for name, submodule in model.named_modules():
if isinstance(submodule, ExllamaQuantLinear):
uses_exllama = True
if not uses_exllama:
raise ValueError(
f"The function exllama_set_max_input_length was called, but the model (instance of {model.__class__.__name__}) does not use the exllama backend for GPTQ. An other implementation is used (exllamav2, cuda, cuda-old, triton) and that the call to exllama_set_max_input_length is unnecessary. Please remove the call to exllama_set_max_input_length or use the exllama v1 backend."
)
device_to_buffers_size = {}
for device, buffers in model.device_to_buffers.items():
device_to_buffers_size[device] = {
"max_dq_buffer_size": buffers["max_dq_buffer_size"],
"max_inner_outer_dim": buffers["max_inner_outer_dim"],
}
# For an unknown reason calling just `del model.device_to_buffers` raises an AttributeError.
for key in list(model.device_to_buffers.keys()):
del model.device_to_buffers[key]
model.device_to_buffers = None
del model.device_to_buffers
gc.collect()
torch.cuda.empty_cache()
cleanup_buffers_cuda()
device_to_buffers = {}
for device, buffers_size in device_to_buffers_size.items():
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
device_to_buffers[device] = {
"temp_state": torch.zeros(
(max_input_length, buffers_size["max_inner_outer_dim"]),
dtype=torch.float16,
device=device,
),
"temp_dq": torch.zeros(
(1, buffers_size["max_dq_buffer_size"]),
dtype=torch.float16,
device=device,
),
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
}
prepare_buffers(
device,
device_to_buffers[device]["temp_state"],
device_to_buffers[device]["temp_dq"],
)
# Buffers need to be persistent to avoid any bug.
model.device_to_buffers = device_to_buffers
return model
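# Usage note (hedged sketch, not additional API): with an already-loaded act-order GPTQ model
# using the exllama v1 backend, the buffer can be grown in place, e.g.
#   model = exllama_set_max_input_length(model, max_input_length=4096)
# which re-allocates each device's fp16 `temp_state` buffer as (4096, max_inner_outer_dim).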
from logging import getLogger
from typing import Optional
import torch
from packaging.version import parse as parse_version
try:
import triton # noqa: F401
TRITON_AVAILABLE = True
except ImportError:
TRITON_AVAILABLE = False
try:
import autogptq_cuda_64 # noqa: F401
AUTOGPTQ_CUDA_AVAILABLE = True
except Exception:
AUTOGPTQ_CUDA_AVAILABLE = False
try:
import exllama_kernels # noqa: F401
EXLLAMA_KERNELS_AVAILABLE = True
except Exception:
EXLLAMA_KERNELS_AVAILABLE = False
try:
import exllamav2_kernels # noqa: F401
EXLLAMAV2_KERNELS_AVAILABLE = True
except Exception:
EXLLAMAV2_KERNELS_AVAILABLE = False
try:
import cQIGen # noqa: F401
QIGEN_AVAILABLE = True
QIGEN_EXCEPTION = None
except Exception as e:
QIGEN_AVAILABLE = False
QIGEN_EXCEPTION = e
try:
import autogptq_marlin_cuda # noqa: F401
MARLIN_AVAILABLE = True
MARLIN_EXCEPTION = None
except Exception as e:
MARLIN_AVAILABLE = False
MARLIN_EXCEPTION = e
logger = getLogger(__name__)
def dynamically_import_QuantLinear(
use_triton: bool,
desc_act: bool,
group_size: int,
bits: int,
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_qigen: bool = False,
use_marlin: bool = False,
use_tritonv2: bool = False,
):
try:
import habana_frameworks.torch.hpu # noqa: F401
    except Exception:
        pass
else:
from ..nn_modules.qlinear.qlinear_hpu import QuantLinear
return QuantLinear
if use_qigen:
if not QIGEN_AVAILABLE:
raise ValueError(
f"QIGen appears to be not available with the error: {QIGEN_EXCEPTION}. Please check your installation or use `use_qigen=False`."
)
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
else:
if use_triton or use_tritonv2:
if torch.version.hip:
logger.warning(
"Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False."
)
if use_tritonv2:
logger.debug("Using tritonv2 for GPTQ")
from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
else:
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
else:
# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
if disable_exllama is None:
if disable_exllamav2:
disable_exllama = False
else:
disable_exllama = True
if bits == 4 and use_marlin:
from ..nn_modules.qlinear.qlinear_marlin import QuantLinear
elif bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
elif not desc_act or group_size == -1:
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
else:
from ..nn_modules.qlinear.qlinear_cuda import QuantLinear
return QuantLinear
def compare_transformers_version(version: str = "v4.28.0", op: str = "eq"):
assert op in ["eq", "lt", "le", "gt", "ge"]
from transformers import __version__
return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))
def compare_pytorch_version(version: str = "v2.0.0", op: str = "eq"):
assert op in ["eq", "lt", "le", "gt", "ge"]
from torch import __version__
return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))
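# Minimal usage sketch: the version helpers above compare the installed packages against a
# reference version string, e.g. to gate features on PyTorch >= 2.0.
if __name__ == "__main__":
    print("torch >= 2.0:", compare_pytorch_version("v2.0.0", op="ge"))
    print("transformers == 4.28.0:", compare_transformers_version("v4.28.0", op="eq"))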
import gc
from logging import getLogger
from typing import Tuple
import torch
from accelerate.utils import find_tied_parameters
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_marlin import _get_perms, unpack_qzeros
from ..quantization import CHECKPOINT_FORMAT, QUANT_METHOD, BaseQuantizeConfig
from .accelerate_utils import load_checkpoint_in_model
from .import_utils import MARLIN_AVAILABLE, MARLIN_EXCEPTION
from .modeling_utils import recurse_getattr, recurse_setattr
if MARLIN_AVAILABLE:
import autogptq_marlin_cuda
logger = getLogger(__name__)
def prepare_model_for_marlin_load(
model,
quantize_config: BaseQuantizeConfig,
quant_linear_class,
torch_dtype,
current_model_save_name,
device_map,
):
# The model (e.g. model.safetensors) is already serialized in the Marlin format, load it directly.
if quantize_config.checkpoint_format == CHECKPOINT_FORMAT.MARLIN:
model_save_name = current_model_save_name
logger.info(f"Loading a GPTQ model, detected Marlin serialized format at {model_save_name}.")
model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)
else:
model_save_name, is_cached = quantize_config.get_cache_file_path(quant_method=QUANT_METHOD.GPTQ,
checkpoint_format=CHECKPOINT_FORMAT.MARLIN)
# If GPTQ model has Marlin version cached locally, load from the cached version (no repacking needed).
if is_cached:
logger.info(
f"Loading a GPTQ model, detected a cached repacked weight for Marlin kernel at {model_save_name}."
)
model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)
# Otherwise, convert the model to Marlin format first and cache locally.
else:
# Loading the GPTQ checkpoint to do the conversion.
# TODO: Avoid loading the model with wrong QuantLinear, and directly use
# Marlin ones. The repacking can be done directly on the safetensors, just
# as for AWQ checkpoints.
load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=current_model_save_name,
device_map=device_map,
offload_state_dict=True,
offload_buffers=True,
)
# Convert model to marlin, repacking weights into Marlin format.
model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=True)
# Safetensors is unable to save tied weights, so we untie them here. Reference: https://github.com/huggingface/safetensors/issues/202
tied_params = find_tied_parameters(model)
for weight_group in tied_params:
for param_name in weight_group:
if isinstance(recurse_getattr(model, param_name), torch.nn.Parameter):
recurse_setattr(
model,
param_name,
torch.nn.Parameter(recurse_getattr(model, param_name).clone()),
)
else:
recurse_setattr(
model,
param_name,
recurse_getattr(model, param_name).clone(),
)
# Cache the converted model.
safe_save(model.state_dict(), model_save_name)
return model, model_save_name
# Validate marlin support
def _validate_marlin_device_support() -> bool:
"""
Validates if the current device is compatible for Marlin.
ref: https://github.com/IST-DASLab/marlin?tab=readme-ov-file#requirements
Returns:
bool: indicates if CUDA device is compatible for Marlin
"""
return torch.cuda.get_device_capability()[0] >= 8
# Adapted from https://github.com/rib-2/marlin/tree/conversion
def _validate_marlin_compatibility(cfg: BaseQuantizeConfig):
if not MARLIN_AVAILABLE:
return f"AutoGPTQ is not compiled with the Marlin kernel, with the following error: {MARLIN_EXCEPTION}"
if cfg.bits != 4:
return f"The quantized model uses a bitwidth different than 4 (found {cfg.bits})"
if cfg.group_size != 128 and cfg.group_size != -1:
return "The quantized model uses a group size that is not 128 or -1 (found quantization_config.group_size)"
if not cfg.sym:
return "The quantized model uses asymmetric quantization"
if cfg.desc_act:
return "The quantized model uses act-order (also called desc-act) scheme"
if cfg.quant_method == QUANT_METHOD.AWQ:
return "awq_gemm format is currently not compatible with marlin"
return None
@torch.no_grad()
def convert_to_marlin(model, model_quantlinear, quantization_config: BaseQuantizeConfig, repack: bool, strict: bool = False):
"""
Converts GPTQ-packed weights to the Marlin format. This assumes that the model already meets Marlin kernel constraints.
Arguments:
repack (`bool`):
Whether to repack the qweights from `model` into the Marlin's QuantLinear layers.
"""
if repack:
message = "Repacking weights to be compatible with Marlin kernel..."
else:
# TODO: load directly Marlin QuantLinear.
message = "Overriding QuantLinear layers to use Marlin's QuantLinear..."
for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
if not isinstance(module, model_quantlinear):
continue
parent_name = ".".join(name.split(".")[:-1])
layer_name = name[len(parent_name) + 1 :]
# We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when
# loading weights from checkpoints holding zero bias.
with torch.device("meta"):
new_module = MarlinQuantLinear(
bits=4,
group_size=module.group_size,
infeatures=module.infeatures,
outfeatures=module.outfeatures,
bias=module.bias is not None,
trainable=False,
)
# workspace is never in the state_dict, thus we need to allocate it manually.
new_module.workspace = torch.zeros(module.outfeatures // 128 * 16, dtype=torch.int, device=module.device)
# Dequantize the weight.
if repack:
marlin_repacked_weight = autogptq_marlin_cuda.gptq_repack(module.qweight)
if strict:
dequantized_qzeros = unpack_qzeros(module.qzeros)
if not torch.all(dequantized_qzeros == 8):
raise ValueError(
"Marlin kernel is compatible only with checkpoints using symmetric quantization."
"Found non-symmetric quantization for the weight {name}."
)
_, _scale_perm, _scale_perm_single = _get_perms()
s = module.scales.data.clone()
if module.group_size != module.infeatures:
s = s.reshape((1, -1))
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, module.outfeatures)).contiguous()
new_module.B = marlin_repacked_weight
new_module.s = s
new_module.bias = module.bias
new_module = new_module.to(module.device)
# Save to parent.
parent_module = model.get_submodule(parent_name)
setattr(parent_module, layer_name, new_module)
# Free cuda memory.
del module
if repack:
del marlin_repacked_weight
gc.collect()
# Set quantization config to be Marlin.
quantization_config.checkpoint_format = CHECKPOINT_FORMAT.MARLIN
return model
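# Minimal usage sketch, assuming the package is importable: validate a quantization config
# against the Marlin kernel constraints before attempting a conversion; a non-None return
# value is a human-readable reason why Marlin cannot be used.
if __name__ == "__main__":
    cfg = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False, sym=True)
    reason = _validate_marlin_compatibility(cfg)
    if reason is not None:
        print(f"Marlin kernel cannot be used: {reason}")
    elif MARLIN_AVAILABLE and torch.cuda.is_available() and _validate_marlin_device_support():
        print("Marlin kernel can be used for this config.")
    else:
        print("Marlin kernel or a compatible GPU is not available.")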
import functools
def recurse_getattr(obj, attr: str):
"""
Recursive `getattr`.
Args:
obj:
A class instance holding the attribute.
attr (`str`):
The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
"""
def _getattr(obj, attr):
return getattr(obj, attr)
return functools.reduce(_getattr, [obj] + attr.split("."))
def recurse_setattr(module, name, value):
"""A function to recursively set attributes to a module."""
if "." not in name:
setattr(module, name, value)
else:
name, rest = name.split(".", 1)
recurse_setattr(getattr(module, name), rest, value)
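# Minimal usage sketch: recurse_getattr / recurse_setattr address nested sub-modules with a
# dotted path, mirroring how parameters are named in `model.named_parameters()`.
if __name__ == "__main__":
    import torch

    net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    weight = recurse_getattr(net, "0.weight")
    recurse_setattr(net, "0.weight", torch.nn.Parameter(torch.zeros_like(weight)))
    print(recurse_getattr(net, "0.weight").abs().sum().item())  # 0.0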
import warnings
from contextlib import contextmanager
from typing import List, Optional, Tuple, Union
import torch
from peft import PeftConfig, PeftModel, PeftType, get_peft_model
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.tuners.adalora import AdaLoraConfig, AdaLoraLayer, AdaLoraModel
from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel
from ..modeling._base import BaseGPTQForCausalLM
from ..nn_modules.qlinear import GeneralQuantLinear
from ..nn_modules.qlinear.qlinear_cuda import QuantLinear as QuantLinearCuda
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearCudaOld
from ..nn_modules.qlinear.qlinear_hpu import QuantLinear as QuantLinearHpu
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as QuantLinearExllama
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear as QuantLinearExllamaV2
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear as QuantLinearQigen
from ..nn_modules.qlinear.qlinear_triton import QuantLinear as QuantLinearTriton
LinearLayer = Union[
torch.nn.Linear,
GeneralQuantLinear,
QuantLinearCuda,
QuantLinearCudaOld,
QuantLinearHpu,
QuantLinearExllama,
QuantLinearExllamaV2,
QuantLinearQigen,
QuantLinearTriton,
]
class GPTQLoraConfig(LoraConfig):
injected_fused_attention: bool = False
injected_fused_mlp: bool = False
def _get_linear_feature_count(linear_layer: LinearLayer) -> Tuple[int, int]:
    # getattr's default argument is evaluated eagerly, so check explicitly before falling back
    # to the quantized-layer attribute names.
    in_features = linear_layer.in_features if hasattr(linear_layer, "in_features") else linear_layer.infeatures
    out_features = linear_layer.out_features if hasattr(linear_layer, "out_features") else linear_layer.outfeatures
    return in_features, out_features
def _get_weight(linear_layer: LinearLayer) -> torch.Tensor:
    return linear_layer.weight if hasattr(linear_layer, "weight") else linear_layer.qweight
class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
def __init__(
self,
adapter_name: str,
linear_module: LinearLayer,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)
in_features, out_features = _get_linear_feature_count(linear_module)
torch.nn.Linear.__init__(self, in_features, out_features)
LoraLayer.__init__(self, in_features, out_features)
self.linear_module = linear_module
delattr(self, "weight")
self.weight = _get_weight(linear_module)
delattr(self, "bias")
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
assert hasattr(linear_module, "weight")
linear_module.weight.data = linear_module.weight.data.T
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
    def merge(self):
        raise NotImplementedError("gptq model does not support merging the lora adapter")
    def unmerge(self):
        raise NotImplementedError("gptq model does not support unmerging the lora adapter")
def forward(self, x: torch.Tensor):
previous_dtype = x.dtype
if self.active_adapter not in self.lora_A.keys():
return self.linear_module(x)
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.unmerge()
result = self.linear_module(x)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = self.linear_module(x)
lora_B = self.lora_B[self.active_adapter]
lora_A = self.lora_A[self.active_adapter]
lora_dropout = self.lora_dropout[self.active_adapter]
scale = self.scaling[self.active_adapter]
x = x.type_as(lora_A.weight.data)
adapter_result = (lora_B(lora_A(lora_dropout(x))) * scale).type_as(result)
result += adapter_result
else:
result = self.linear_module(x)
result = result.to(previous_dtype)
return result
class GPTQLoraModel(LoraModel):
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
if not isinstance(new_module, GPTQLoraLinear):
new_module.weight = old_module.weight
if hasattr(old_module, "bias"):
if old_module.bias is not None:
new_module.bias = old_module.bias
if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.weight.device)
# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
device = (list(old_module.parameters()) + list(old_module.buffers()))[0].device
module.to(device)
@staticmethod
def _create_new_module(
lora_config: GPTQLoraConfig,
adapter_name: str,
target: torch.nn.Linear,
**kwargs,
):
gptq_quantlinears = {
GeneralQuantLinear,
QuantLinearCuda,
QuantLinearCudaOld,
QuantLinearHpu,
QuantLinearExllama,
QuantLinearExllamaV2,
QuantLinearQigen,
QuantLinearTriton,
}
is_gptq_layer = any(isinstance(target, cls) for cls in gptq_quantlinears)
if is_gptq_layer:
return GPTQLoraLinear(
adapter_name,
target,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
fan_in_fan_out=lora_config.fan_in_fan_out,
)
else:
return LoraModel._create_new_module(lora_config, adapter_name, target, **kwargs)
    def merge_adapter(self):
        raise NotImplementedError("gptq model does not support merging the lora adapter")
    def unmerge_adapter(self):
        raise NotImplementedError("gptq model does not support unmerging the lora adapter")
    def merge_and_unload(self):
        raise NotImplementedError("gptq model does not support merge_and_unload")
class GPTQAdaLoraConfig(AdaLoraConfig):
injected_fused_attention: bool = False
injected_fused_mlp: bool = False
class GPTQSVDLinear(torch.nn.Linear, AdaLoraLayer):
def __init__(
self,
adapter_name: str,
linear_module: LinearLayer,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)
in_features, out_features = _get_linear_feature_count(linear_module)
torch.nn.Linear.__init__(self, in_features, out_features)
AdaLoraLayer.__init__(self, in_features, out_features)
self.linear_module = linear_module
delattr(self, "weight")
self.weight = _get_weight(linear_module)
delattr(self, "bias")
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
assert hasattr(linear_module, "weight")
linear_module.weight.data = linear_module.weight.data.T
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
    def merge(self):
        raise NotImplementedError("gptq model does not support merging the adalora adapter")
    def unmerge(self):
        raise NotImplementedError("gptq model does not support unmerging the adalora adapter")
def forward(self, x: torch.Tensor):
if self.active_adapter not in self.lora_A.keys():
return self.linear_module(x)
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.unmerge()
result = self.linear_module(x)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = self.linear_module(x)
result += (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
else:
result = self.linear_module(x)
return result
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
# Peft standard values seems too high
# Still not ideal, just not causing NaNs with fp16 anymore
torch.nn.init.normal_(self.lora_E[adapter_name], mean=0.0, std=0.005)
torch.clamp_(self.lora_E[adapter_name].data, -0.1, 0.1)
torch.nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.005)
torch.clamp_(self.lora_A[adapter_name].data, -0.1, 0.1)
torch.nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.005)
torch.clamp_(self.lora_B[adapter_name].data, -0.1, 0.1)
class GPTQAdaLoraModel(AdaLoraModel):
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
device = (list(old_module.parameters()) + list(old_module.buffers()))[0].device
module.to(device)
@staticmethod
def _create_new_module(
lora_config: GPTQLoraConfig,
adapter_name: str,
target: torch.nn.Linear,
**kwargs,
):
gptq_quantlinears = {
GeneralQuantLinear,
QuantLinearCuda,
QuantLinearCudaOld,
QuantLinearHpu,
QuantLinearExllama,
QuantLinearExllamaV2,
QuantLinearQigen,
QuantLinearTriton,
}
is_gptq_layer = any(isinstance(target, cls) for cls in gptq_quantlinears)
if is_gptq_layer:
return GPTQSVDLinear(
adapter_name,
target,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
fan_in_fan_out=lora_config.fan_in_fan_out,
)
else:
return LoraModel._create_new_module(lora_config, adapter_name, target, **kwargs)
    def merge_adapter(self):
        raise NotImplementedError("gptq model does not support merging the adalora adapter")
    def unmerge_adapter(self):
        raise NotImplementedError("gptq model does not support unmerging the adalora adapter")
    def merge_and_unload(self):
        raise NotImplementedError("gptq model does not support merge_and_unload")
def find_all_linear_names(
model: BaseGPTQForCausalLM,
ignore: Optional[List[str]] = None,
ignore_lm_head: bool = True,
):
if not ignore:
ignore = []
lm_head_name = model.lm_head_name
if ignore_lm_head and lm_head_name not in ignore:
ignore.append(lm_head_name)
results = set()
for n, m in model.named_modules():
if isinstance(m, torch.nn.Linear):
res = n.split(".")[-1]
if res not in ignore:
results.add(res)
return list(results)
@contextmanager
def hijack_peft_mappings():
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
    try:
        yield
    finally:
        PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
        PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
        PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
        PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
def get_gptq_peft_model(
model: BaseGPTQForCausalLM,
peft_config: PeftConfig = None,
model_id: str = None,
adapter_name: str = "default",
auto_find_all_linears: bool = True,
train_mode: bool = False,
):
if train_mode and not model.trainable:
model.enable_trainable_mode()
if train_mode and not peft_config:
raise ValueError("peft_config not specified when in train mode.")
if not train_mode and not model_id:
raise ValueError("model_id(where to load adapters) not specified when in inference mode.")
if model.fused_attn_module_type is not None and not model.injected_fused_attention:
peft_types = [PeftType.LORA.value, PeftType.ADALORA.value]
        warnings.warn(
            f"You can ignore this warning if the peft type you use is not in {peft_types}.\n"
            f"{model.__class__.__name__} supports injecting fused attention, but it is not enabled this time. "
            "If you are training adapters, you must also disable fused attention injection when loading the quantized "
            "base model at inference time, otherwise the adapters may not be added to the base model properly. "
            "If you are loading adapters for inference, check the adapter's config file to see whether the adapters "
            "were trained on a base model loaded without fused attention injection."
        )
if model.injected_fused_mlp:
raise NotImplementedError(
"GPTQ model that enables fused mlp injection is not supported to integrate with peft."
)
if train_mode:
peft_type = peft_config.peft_type
if not isinstance(peft_type, str):
peft_type = peft_type.value
if peft_type in [PeftType.LORA.value, PeftType.ADALORA.value]:
if auto_find_all_linears:
peft_config.target_modules = find_all_linear_names(model, ignore_lm_head=True)
if peft_type == PeftType.LORA.value and not isinstance(peft_config, GPTQLoraConfig):
peft_config = GPTQLoraConfig(**peft_config.to_dict())
if peft_type == PeftType.ADALORA.value and not isinstance(peft_config, GPTQAdaLoraConfig):
peft_config = GPTQAdaLoraConfig(**peft_config.to_dict())
peft_config.injected_fused_attention = model.injected_fused_attention
peft_config.injected_fused_mlp = model.injected_fused_mlp
if peft_type == PeftType.ADAPTION_PROMPT.value:
if peft_config.adapter_layers > model.config.num_hidden_layers:
warnings.warn(
f"model has only {model.config.num_hidden_layers} layers "
f"but adapter_layers is set to {peft_config.adapter_layers}, "
f"will reset value to {model.config.num_hidden_layers}."
)
peft_config.adapter_layers = model.config.num_hidden_layers
if model.injected_fused_attention:
raise NotImplementedError(
"model with fused attention injected isn't supported to use ADAPTION_PROMPT peft type yet."
)
with hijack_peft_mappings():
try:
if train_mode:
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
else:
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
        except Exception as exc:
            raise NotImplementedError(
                f"{model.__class__.__name__} does not support the {peft_config.peft_type.value} peft type yet."
            ) from exc
return peft_model
__all__ = [
"GPTQLoraConfig",
"GPTQLoraModel",
"GPTQAdaLoraConfig",
"GPTQAdaLoraModel",
"find_all_linear_names",
"get_gptq_peft_model",
]
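# Minimal usage sketch, assuming the package and peft are importable: build a GPTQLoraConfig the
# same way a plain LoraConfig would be built; target_modules here are placeholders for the
# quantized Linear names usually discovered via find_all_linear_names(model).
if __name__ == "__main__":
    lora_cfg = GPTQLoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["qkv_proj", "out_proj"],  # placeholder module names
        task_type="CAUSAL_LM",
    )
    print(lora_cfg.peft_type, lora_cfg.target_modules)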