Unverified Commit 34085edc authored by Ilyas Moutawwakil, committed by GitHub

Marlin symmetric quantization and inference (#320)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent f018d2b7
@@ -13,6 +13,7 @@ from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
@@ -103,6 +104,7 @@ class BaseAWQForCausalLM(nn.Module):
tokenizer,
self.quant_config.w_bit,
self.quant_config.q_group_size,
self.quant_config.zero_point,
self.quant_config.version,
calib_data,
split,
@@ -149,6 +151,7 @@ class BaseAWQForCausalLM(nn.Module):
# Save model and config files with empty state dict
self.model.config.quantization_config = self.quant_config.to_transformers_dict()
self.model.generation_config.do_sample = True
self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
self.quant_config.save_pretrained(save_dir)
@@ -302,7 +305,10 @@ class BaseAWQForCausalLM(nn.Module):
if fuse_layers:
self.fuse_layers(model)
if use_exllama:
if quant_config.version == "Marlin":
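            # allocate the non-persistent Marlin workspace buffers needed by the kernel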
model = marlin_post_init(model)
elif use_exllama:
# creates q4 handle
model = exllama_post_init(model)
elif use_exllama_v2:
@@ -375,7 +381,6 @@ class BaseAWQForCausalLM(nn.Module):
self, model, quant_config, version, use_exllama, use_exllama_v2
):
# Real quantization of weights
assert quant_config.zero_point, "We only support zero_point quantization now."
assert not (
version == "GEMV" and (use_exllama or use_exllama_v2)
), "Exllama kernels only support GEMM version."
@@ -399,7 +404,9 @@ class BaseAWQForCausalLM(nn.Module):
# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
if use_exllama:
if version == "Marlin":
q_linear_module = WQLinear_Marlin
elif use_exllama:
q_linear_module = WQLinear_Exllama
elif use_exllama_v2:
q_linear_module = WQLinear_ExllamaV2
......
@@ -2,3 +2,4 @@ from .exllama import WQLinear_Exllama
from .exllamav2 import WQLinear_ExllamaV2
from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin
\ No newline at end of file
@@ -109,6 +109,11 @@ class WQLinear_Exllama(nn.Module):
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
assert EXL_INSTALLED, (
"ExllamaV2 kernels are not installed. "
"Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
......
@@ -10,7 +10,6 @@ try:
except:
EXLV2_INSTALLED = False
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@@ -23,7 +22,6 @@ class WQLinear_ExllamaV2(nn.Module):
raise NotImplementedError("Only 4-bit are supported for now.")
self.q_handle = None
self.q_tensors = None
self.w_bit = w_bit
self.in_features = in_features
@@ -134,8 +132,8 @@ class WQLinear_ExllamaV2(nn.Module):
"Use exllamav2_post_init() on the whole model."
)
assert EXLV2_INSTALLED, (
"Exllama kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
"ExllamaV2 kernels are not installed. "
"Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
)
input_dtype = x.dtype
......
import torch
import torch.nn as nn
import numpy as np
try:
import marlin_cuda # with CUDA kernels (AutoAWQ_kernels)
MARLIN_INSTALLED = True
except:
MARLIN_INSTALLED = False
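# _get_perms() builds the static permutation indices used to pre-shuffle the
# packed 4-bit weights (_perm) and the per-group scales (_scale_perm /
# _scale_perm_single) into the tile layout the Marlin CUDA kernel expects.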
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
class WQLinear_Marlin(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
self.max_par = 8 # partitioning for large inputs
        # quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
######################################################
        ## These shapes are specific to Marlin models ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features // 16, out_features * 16 // 8),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
######################################################
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(
cls,
linear,
w_bit,
group_size,
init_only=False,
scales=None,
zeros=None,
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
        if init_only:  # just prepare for loading the state dict
return awq_linear
assert zeros is None and scales is not None
tile = 16
maxq = 2**4 - 1
s = scales.t()
w = linear.weight.data.t()
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((-1, awq_linear.group_size, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape((awq_linear.group_size, -1))
s = s.reshape((1, -1))
w = torch.round(w / s).int()
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((awq_linear.group_size, -1, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape(
(awq_linear.in_features, awq_linear.out_features)
).contiguous()
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, awq_linear.out_features)).contiguous()
w = w.reshape(
(
awq_linear.in_features // tile,
tile,
awq_linear.out_features // tile,
tile,
)
)
w = w.permute((0, 2, 1, 3))
w = w.reshape((awq_linear.in_features // tile, awq_linear.out_features * tile))
res = w
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
res = res.cpu().numpy().astype(np.uint32)
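        # pack eight 4-bit values into each 32-bit word, least-significant nibble first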
for i in range(8):
q |= res[:, i::8] << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
awq_linear.qweight[:] = q.to(awq_linear.qweight.device)
awq_linear.scales[:] = s.to(awq_linear.qweight.device)
if awq_linear.bias is not None:
awq_linear.bias[:] = linear.bias.data.to(awq_linear.bias.device)
return awq_linear
def post_init(self):
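        # Non-persistent scratch buffer required by marlin_cuda.mul(); it is
        # allocated here, after the weights are loaded, because it is not
        # stored in the checkpoint's state dict.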
self.register_buffer(
"workspace",
torch.zeros(
self.out_features // 128 * self.max_par,
dtype=torch.int32,
device=self.qweight.device,
),
persistent=False,
)
@torch.no_grad()
def forward(self, x):
assert hasattr(self, "workspace"), (
"module.post_init() must be called before module.forward(). "
"Use marlin_post_init() on the whole model."
)
assert MARLIN_INSTALLED, (
"Marlin kernels are not installed. "
"Please install AWQ compatible Marlin kernels from AutoAWQ_kernels."
)
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
marlin_cuda.mul(
x,
self.qweight,
out,
self.scales,
self.workspace,
-1, # thread_k
-1, # thread_n
-1, # sms
self.max_par,
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
def marlin_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Marlin):
submodule.post_init()
return model
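For reference, a minimal standalone sketch of how this module is intended to be used (not part of the diff; it assumes a CUDA device with the marlin_cuda kernels from AutoAWQ_kernels installed, and it fabricates symmetric per-group scales just for illustration):

import torch
import torch.nn as nn
from awq.modules.linear.marlin import WQLinear_Marlin

in_features, out_features, group_size = 4096, 4096, 128
linear = nn.Linear(in_features, out_features, bias=False).half().cuda()

# symmetric 4-bit scales: max_int = 2 ** (4 - 1) - 1 = 7
scales = (
    linear.weight.data.view(out_features, -1, group_size).abs().amax(-1) / 7
).clamp(min=1e-5)  # [out_features, in_features // group_size]

qlinear = WQLinear_Marlin.from_linear(
    linear, w_bit=4, group_size=group_size, scales=scales
)
qlinear.post_init()  # or marlin_post_init(model) on a whole model

with torch.no_grad():
    y = qlinear(torch.randn(1, in_features, dtype=torch.float16, device="cuda"))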
@@ -4,82 +4,110 @@ import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List
from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin
from awq.utils.module import (
append_str_prefix,
get_op_name,
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize
exclude_layers_to_not_quantize,
)
class AwqQuantizer:
def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
calib_data, split, text_column, duo_scaling, modules_to_not_convert=None,
export_compatible=False) -> None:
def __init__(
self,
awq_model,
model,
tokenizer,
w_bit,
group_size,
zero_point,
version,
calib_data,
split,
text_column,
duo_scaling,
modules_to_not_convert=None,
export_compatible=False,
) -> None:
self.awq_model = awq_model
self.model = model
self.tokenizer = tokenizer
self.w_bit = w_bit
self.group_size = group_size
self.zero_point = zero_point
self.version = version
self.calib_data = calib_data
self.split = split
self.text_column = text_column
self.duo_scaling = duo_scaling
self.export_compatible = export_compatible
self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
self.modules_to_not_convert = (
modules_to_not_convert if modules_to_not_convert is not None else []
)
self.modules, self.module_kwargs, self.inps = self.init_quant()
def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
def pseudo_quantize_tensor(self, w: torch.Tensor):
org_w_shape = w.shape
if self.group_size > 0:
assert org_w_shape[-1] % self.group_size == 0
w = w.reshape(-1, self.group_size)
assert w.dim() == 2
assert torch.isnan(w).sum() == 0
# zero point quantization
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2 ** self.w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
if self.zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**self.w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
w = (
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
) * scales
zeros = zeros.view(org_w_shape[0], -1)
else:
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (self.w_bit - 1) - 1
min_int = -(2 ** (self.w_bit - 1))
scales = max_val / max_int
zeros = None
w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
assert torch.isnan(w).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w = w.reshape(org_w_shape)
if get_scale_zp:
return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
else:
return w
def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: torch.Tensor):
# get repeated count
repeat_count = w.weight.data.shape[-1] // zeros.shape[-1]
return w, scales, zeros
# get zeros and scales in correct shape
zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
def pseudo_dequantize_tensor(
self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
):
# get repeated count
repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)
# dequantize
w = (w.weight.data - zeros) * scales
if self.zero_point:
zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
w = (w.weight.data - zeros) * scales
else:
w = w.weight.data * scales
return w
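As a worked toy example of the symmetric (zero_point=False) branch above, with w_bit=4 (illustrative, not part of the diff):

import torch

w = torch.tensor([[0.9, -0.4, 0.05, -1.2]])
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)  # 1.2
scales = max_val / (2 ** (4 - 1) - 1)                        # 1.2 / 7 ≈ 0.1714
w_q = torch.clamp(torch.round(w / scales), -8, 7) * scales   # ≈ [0.857, -0.343, 0.0, -1.2]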
def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
@@ -94,10 +122,14 @@ class AwqQuantizer:
common_device = next(self.modules[i].parameters()).device
if self.module_kwargs.get("position_ids") is not None:
self.module_kwargs["position_ids"] = self.module_kwargs["position_ids"].to(common_device)
self.module_kwargs["position_ids"] = self.module_kwargs[
"position_ids"
].to(common_device)
if self.module_kwargs.get("attention_mask") is not None:
self.module_kwargs["attention_mask"] = self.module_kwargs["attention_mask"].to(common_device)
self.module_kwargs["attention_mask"] = self.module_kwargs[
"attention_mask"
].to(common_device)
self.inps = self.inps.to(common_device)
@@ -105,7 +137,9 @@ class AwqQuantizer:
named_linears = get_named_linears(self.modules[i])
# Filter out the linear layers we don't want to exclude
named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
named_linears = exclude_layers_to_not_quantize(
named_linears, self.modules_to_not_convert
)
input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
@@ -114,53 +148,69 @@ class AwqQuantizer:
module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
self.modules[i], input_feat, self.module_kwargs
)
scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
scales_list = [
self._search_best_scale(self.modules[i], **layer)
for layer in module_config
]
apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")
scales_list = append_str_prefix(
scales_list, get_op_name(self.model, self.modules[i]) + "."
)
# [STEP 3]: Compute and apply clipping list
clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
clip_list = self._search_best_clip(
self.modules[i], named_linears, input_feat
)
apply_clip(self.modules[i], clip_list)
clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
clip_list = append_str_prefix(
clip_list, get_op_name(self.model, self.modules[i]) + "."
)
# [STEP 4]: Quantize weights
if not self.export_compatible:
self._apply_quant(self.modules[i], named_linears)
clear_memory()
def pack(self):
for i in tqdm(range(len(self.modules)), desc="Packing"):
named_linears = get_named_linears(self.modules[i])
named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
named_linears = exclude_layers_to_not_quantize(
named_linears, self.modules_to_not_convert
)
self._apply_quant(self.modules[i], named_linears)
clear_memory()
def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
for name, linear_layer in named_linears.items():
# NOTE: small regression in perplexity if linear layer uses .cpu().float()
linear_layer = linear_layer.to(get_best_device()).half()
linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
linear_layer.weight.data,
get_scale_zp=True
linear_layer.weight.data
)
if self.version == 'GEMM':
if self.version == "GEMM":
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM
elif self.version == 'GEMV':
elif self.version == "GEMV":
q_linear_module = WQLinear_GEMV
elif self.version == "Marlin":
q_linear_module = WQLinear_Marlin
else:
raise ValueError(f"Unknown version {self.version}")
q_linear = q_linear_module.from_linear(
linear=linear_layer,
w_bit=self.w_bit,
group_size=self.group_size,
init_only=False,
scales=scales,
zeros=zeros
zeros=zeros,
)
linear_layer.cpu()
@@ -169,14 +219,22 @@ class AwqQuantizer:
clear_memory()
@torch.no_grad()
def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
def _search_best_scale(
self,
module,
prev_op,
layers: List[nn.Linear],
inp: torch.Tensor,
module2inspect=None,
kwargs={},
):
if module2inspect is None:
assert len(layers) == 1
module2inspect = layers[0]
if "use_cache" in kwargs:
kwargs.pop("use_cache")
# Put x on the right device
inp = inp.to(next(module2inspect.parameters()).device)
@@ -199,17 +257,28 @@ class AwqQuantizer:
fp16_output = module2inspect(inp, **module_kwargs)
if isinstance(fp16_output, tuple):
fp16_output = fp16_output[0]
# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_max, x_max, module2inspect,
layers, fp16_output, module_kwargs
inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
)
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
fp16_output, kwargs={}):
return (
get_op_name(module, prev_op),
tuple([get_op_name(module, m) for m in layers]),
best_scales,
)
def _compute_best_scale(
self,
x,
w_max,
x_max,
module2inspect,
linears2scale: List[nn.Linear],
fp16_output,
kwargs={},
):
"""
Compute loss and select best scales
@@ -223,21 +292,21 @@ class AwqQuantizer:
history = []
best_ratio = -1
best_scales = None
best_error = float('inf')
best_error = float("inf")
org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
device = x.device
x_max = x_max.view(-1).to(device)
w_max = w_max.view(-1).to(device)
for ratio in range(n_grid):
# create new scales
ratio = ratio / n_grid
# NOTE: s^-1 * x is fused here, according to paper
if self.duo_scaling:
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
else:
scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
@@ -246,15 +315,19 @@ class AwqQuantizer:
# Q(W * s)
for fc in linears2scale:
fc.weight.mul_(scales_view)
fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
fc.weight.data = (
self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
)
# W * X
int_w_output = module2inspect(x, **kwargs)
if isinstance(int_w_output, tuple):
int_w_output = int_w_output[0]
# compute mean squared error (L2 norm)
loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
loss = (
(fp16_output - int_w_output).float().pow(2).mean().item()
) # NOTE: float prevents overflow
history.append(loss)
if loss < best_error:
@@ -284,30 +357,36 @@ class AwqQuantizer:
named_linears[name].to(get_best_device())
max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
@torch.no_grad()
def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
def _compute_best_clip(
self,
w: torch.Tensor,
input_feat: torch.Tensor,
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
):
assert w.dim() == 2
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = self.group_size if self.group_size > 0 else w.shape[1]
group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
w = w.reshape(w.shape[0], 1, -1, group_size)
input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
w = w.reshape(org_w_shape[0], 1, -1, group_size)
oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64 # prevent OOM
assert w.shape[0] % oc_batch_size == 0
oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM
assert org_w_shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
for i_b in range(w.shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
for i_b in range(org_w_shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]
org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1
@@ -318,9 +397,9 @@ class AwqQuantizer:
for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
min_val = -max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = self.pseudo_quantize_tensor(cur_w)
q_w = self.pseudo_quantize_tensor(cur_w)[0]
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
@@ -339,11 +418,15 @@ class AwqQuantizer:
return best_max_val.squeeze(1)
def init_quant(self, n_samples=128, seqlen=512):
def init_quant(self, n_samples=2, seqlen=512):
modules = self.awq_model.get_model_layers(self.model)
samples = get_calib_dataset(
data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, block_size=seqlen,
split=self.split, text_column=self.text_column
data=self.calib_data,
tokenizer=self.tokenizer,
n_samples=n_samples,
block_size=seqlen,
split=self.split,
text_column=self.text_column,
)
samples = torch.cat(samples, dim=0)
@@ -353,7 +436,7 @@ class AwqQuantizer:
best_device = get_best_device()
modules[0] = modules[0].to(best_device)
self.awq_model.move_embed(self.model, best_device)
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
@@ -381,7 +464,7 @@ class AwqQuantizer:
self.model(samples.to(next(self.model.parameters()).device))
except ValueError: # work with early exit
pass
# Update the layer kwargs with `prepare_inputs_for_generation` method
# that takes care of everything to avoid unexpected errors.
layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
@@ -394,14 +477,14 @@ class AwqQuantizer:
modules[0] = modules[0].cpu()
self.awq_model.move_embed(self.model, "cpu")
clear_memory()
if layer_kwargs.get("attention_mask") is not None:
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(best_device)
return modules, layer_kwargs, inps
def _get_input_feat(self, layer, named_linears):
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
@@ -414,15 +497,20 @@ class AwqQuantizer:
# FIXME: Workaround for Mixtral to use block_sparse_moe input features
if self.awq_model.model_type == "mixtral":
named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}
named_linears = {
**named_linears,
"block_sparse_moe": layer.block_sparse_moe,
}
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
feat_dict=input_feat)))
handles.append(
named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
)
)
self.inps = self.inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
# Sanitize the kwargs in case we use transformers version that contains
# kwargs that are not handled by the module.
# Useful for trust_remote_code models.
@@ -433,15 +521,14 @@ class AwqQuantizer:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
return input_feat
return input_feat
def _sanitize_kwargs(self, inputs_kwargs, module):
"""
Remove the arguments that are not supported in the module's
forward pass to avoid breaking behaviour between different versions
of transformers.
of transformers.
Args:
inputs_kwargs (`dict`):
@@ -451,7 +538,7 @@ class AwqQuantizer:
"""
module_signature = inspect.signature(module.forward).parameters
sanitized_kwargs = {}
for k, v in inputs_kwargs.items():
for k, v in inputs_kwargs.items():
if k in module_signature:
sanitized_kwargs[k] = v
return sanitized_kwargs
\ No newline at end of file
return sanitized_kwargs
import torch
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin
from awq.modules.linear.exllama import WQLinear_Exllama
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
@@ -58,8 +60,10 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
q_linear = WQLinear_GEMM
elif isinstance(q_proj, WQLinear_Exllama):
q_linear = WQLinear_Exllama
else:
elif isinstance(q_proj, WQLinear_ExllamaV2):
q_linear = WQLinear_ExllamaV2
elif isinstance(q_proj, WQLinear_Marlin):
q_linear = WQLinear_Marlin
qkv_layer = q_linear(
q_proj.w_bit,
@@ -87,7 +91,11 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
elif isinstance(q_proj, WQLinear_Marlin):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
# workspace is created in post_init
qkv_layer.bias = bias
return qkv_layer
......
import torch
from typing import List
Q_BITS = 4
STORAGE_BITS = 32
PACK_NUM = STORAGE_BITS // Q_BITS
ORDINAL_PACK_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
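# AWQ packs the eight 4-bit values of every 32-bit word in the interleaved
# order [0, 2, 4, 6, 1, 3, 5, 7]; REVERSE_AWQ_PACK_ORDER is the inverse
# permutation that restores ordinal order.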
def pack(imatrix: torch.Tensor, direction: str = "column"):
"""
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of packing, either "column" or "row"
Returns:
qmatrix (torch.Tensor): packed matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device)
imatrix = imatrix.to(torch.int8)
    imatrix = torch.bitwise_and(imatrix, 0x0F)  # keep only the low 4 bits (guard against sign/overflow)
if direction == "column":
imatrix = imatrix.view(-1, imatrix.shape[1] // PACK_NUM, PACK_NUM)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
elif direction == "row":
imatrix = imatrix.view(imatrix.shape[0] // PACK_NUM, PACK_NUM, -1)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
qmatrix = qmatrix.to(torch.int32)
return qmatrix
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
"""
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
Args:
qmatrix (torch.Tensor): matrix of packed integers
direction (str): direction of unpacking, either "column" or "row"
Returns:
imatrix (torch.Tensor): matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=qmatrix.device)
if direction == "column":
imatrix = torch.bitwise_right_shift(
qmatrix[:, :, None], shifts[None, None, :]
).view(qmatrix.shape[0], -1)
elif direction == "row":
imatrix = torch.bitwise_right_shift(
qmatrix[:, None, :], shifts[None, :, None]
).view(-1, qmatrix.shape[-1])
    imatrix = imatrix.to(torch.int8) & 0x0F  # keep only the low 4 bits (guard against sign/overflow)
return imatrix
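A quick round-trip sketch for pack/unpack (illustrative only):

import torch

imatrix = torch.randint(0, 16, (8, 16), dtype=torch.int8)
qmatrix = pack(imatrix, direction="column")   # shape (8, 2), int32
print(torch.equal(unpack(qmatrix, direction="column"), imatrix))  # True: the low nibbles round-trip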
def quantize(fmatrix, scales, zeros, group_size):
"""
Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.
Args:
fmatrix (torch.Tensor): matrix of 16-bit floats
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
imatrix (torch.Tensor): matrix of 4-bit integers
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = torch.round(
(
fmatrix / scales.repeat_interleave(group_size, dim=0)
+ zeros.repeat_interleave(group_size, dim=0)
)
)
imatrix = imatrix.to(torch.int8) & 0x0F
return imatrix
def dequantize(imatrix, scales, zeros, group_size):
"""
Dequantizes a 4-bit integer matrix into a float matrix.
Args:
imatrix (torch.Tensor): matrix of 4-bit integers
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
fmatrix (torch.Tensor): matrix of 16-bit floats
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = imatrix.to(torch.int8) & 0x0F
fmatrix = (
imatrix - zeros.repeat_interleave(group_size, dim=0)
) * scales.repeat_interleave(group_size, dim=0)
fmatrix = fmatrix.to(torch.float16)
return fmatrix
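And a small quantize/dequantize round trip with fabricated per-group scales and mid-point zeros (illustrative only):

import torch

group_size = 4
fmatrix = torch.randn(8, 6, dtype=torch.float16)
scales = (fmatrix.view(-1, group_size, 6).abs().amax(dim=1) / 7).clamp(min=1e-5)  # (2, 6)
zeros = torch.full((2, 6), 8, dtype=torch.int8)
imatrix = quantize(fmatrix, scales, zeros, group_size)
recon = dequantize(imatrix, scales, zeros, group_size)
print((fmatrix - recon).abs().max())  # at most ~half a quantization step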
def apply_order(
imatrix: torch.Tensor,
direction: str = "column",
order: List[int] = ORDINAL_PACK_ORDER,
):
"""
Applies the order to a 4-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of applying order, either "column" or "row"
order (List[int]): order to apply, default is ordinal packing order
Returns:
imatrix (torch.Tensor): matrix of integers
"""
if direction == "column":
imatrix = imatrix.view(-1, PACK_NUM)[:, order].view(imatrix.shape)
elif direction == "row":
imatrix = imatrix.view(PACK_NUM, -1)[order, :].view(imatrix.shape)
return imatrix
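For instance, applying the AWQ order to an ordinal row of nibbles:

print(apply_order(torch.arange(8).view(1, 8), order=AWQ_PACK_ORDER))
# tensor([[0, 2, 4, 6, 1, 3, 5, 7]])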
def awq_to_exllama(qweight, qzeros):
# awq uses column packing for both weights and zeros
izeros = unpack(qzeros, direction="column")
iweights = unpack(qweight, direction="column")
# Reverse the order of the iweight and izeros tensors
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
# Subtract 1 from the izeros tensor (exllama adds 1 during inference)
izeros = izeros - 1
# exllama uses row packing for weights and column packing for zeros
qzeros = pack(izeros, direction="column")
qweight = pack(iweights, direction="row")
return qweight, qzeros
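Hypothetical usage, assuming `gemm` is an already-quantized WQLinear_GEMM layer whose buffers are to be repacked for the Exllama kernels:

exl_qweight, exl_qzeros = awq_to_exllama(gemm.qweight, gemm.qzeros)  # `gemm` is assumed, not defined here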
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
# Load model
......
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "IlyasMoutawwakil/vicuna-7b-v1.5-awq-marlin"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
<|system|>
</s>
<|user|>
{prompt}</s>
<|assistant|>"""
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
\ No newline at end of file
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq-marlin'
quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" }
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file