Unverified commit 34085edc authored by Ilyas Moutawwakil, committed by GitHub

Marlin symmetric quantization and inference (#320)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent f018d2b7
@@ -13,6 +13,7 @@ from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
+from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
@@ -103,6 +104,7 @@ class BaseAWQForCausalLM(nn.Module):
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
+            self.quant_config.zero_point,
            self.quant_config.version,
            calib_data,
            split,
@@ -149,6 +151,7 @@ class BaseAWQForCausalLM(nn.Module):
        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
+        self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)
@@ -302,7 +305,10 @@ class BaseAWQForCausalLM(nn.Module):
        if fuse_layers:
            self.fuse_layers(model)

-        if use_exllama:
+        if quant_config.version == "Marlin":
+            model = marlin_post_init(model)
+        elif use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
@@ -375,7 +381,6 @@ class BaseAWQForCausalLM(nn.Module):
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
        # Real quantization of weights
-        assert quant_config.zero_point, "We only support zero_point quantization now."
        assert not (
            version == "GEMV" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."
@@ -399,7 +404,9 @@ class BaseAWQForCausalLM(nn.Module):
            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
-                if use_exllama:
+                if version == "Marlin":
+                    q_linear_module = WQLinear_Marlin
+                elif use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
......
@@ -2,3 +2,4 @@ from .exllama import WQLinear_Exllama
from .exllamav2 import WQLinear_ExllamaV2
from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV
+from .marlin import WQLinear_Marlin
\ No newline at end of file
@@ -109,6 +109,11 @@ class WQLinear_Exllama(nn.Module):
            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
        )

+        assert EXL_INSTALLED, (
+            "Exllama kernels are not installed. "
+            "Please install AWQ compatible Exllama kernels from AutoAWQ_kernels."
+        )

        input_dtype = x.dtype
        out_shape = x.shape[:-1] + (self.out_features,)
......
@@ -10,7 +10,6 @@ try:
except:
    EXLV2_INSTALLED = False

-# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@@ -23,7 +22,6 @@ class WQLinear_ExllamaV2(nn.Module):
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.q_handle = None
-        self.q_tensors = None
        self.w_bit = w_bit
        self.in_features = in_features
@@ -134,8 +132,8 @@ class WQLinear_ExllamaV2(nn.Module):
            "Use exllamav2_post_init() on the whole model."
        )
        assert EXLV2_INSTALLED, (
-            "Exllama kernels could not be loaded. "
-            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
+            "ExllamaV2 kernels are not installed. "
+            "Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
        )

        input_dtype = x.dtype
......
import torch
import torch.nn as nn
import numpy as np
try:
    import marlin_cuda  # Marlin CUDA kernels, shipped with AutoAWQ_kernels
MARLIN_INSTALLED = True
except:
MARLIN_INSTALLED = False
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
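
# Note on the permutations above (not documented in the original source): `_perm`
# reorders the 4-bit weight codes before packing so that each 32-bit word lands in
# the fragment layout consumed by the marlin_cuda kernel, while `_scale_perm`
# (grouped scales) and `_scale_perm_single` (per-channel scales, i.e.
# group_size == in_features) reorder the fp16 scales to match. The exact ordering
# is dictated by the kernel and mirrors the upstream Marlin repacking code.
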
class WQLinear_Marlin(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
self.max_par = 8 # partitioning for large inputs
        # quick sanity checks (make sure alignment is correct)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
######################################################
        ## The buffer shapes below are specific to Marlin kernels ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features // 16, out_features * 16 // 8),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
######################################################
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(
cls,
linear,
w_bit,
group_size,
init_only=False,
scales=None,
zeros=None,
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
assert zeros is None and scales is not None
tile = 16
maxq = 2**4 - 1
s = scales.t()
w = linear.weight.data.t()
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((-1, awq_linear.group_size, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape((awq_linear.group_size, -1))
s = s.reshape((1, -1))
w = torch.round(w / s).int()
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((awq_linear.group_size, -1, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape(
(awq_linear.in_features, awq_linear.out_features)
).contiguous()
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, awq_linear.out_features)).contiguous()
w = w.reshape(
(
awq_linear.in_features // tile,
tile,
awq_linear.out_features // tile,
tile,
)
)
w = w.permute((0, 2, 1, 3))
w = w.reshape((awq_linear.in_features // tile, awq_linear.out_features * tile))
res = w
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
res = res.cpu().numpy().astype(np.uint32)
for i in range(8):
q |= res[:, i::8] << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
awq_linear.qweight[:] = q.to(awq_linear.qweight.device)
awq_linear.scales[:] = s.to(awq_linear.qweight.device)
if awq_linear.bias is not None:
awq_linear.bias[:] = linear.bias.data.to(awq_linear.bias.device)
return awq_linear
def post_init(self):
self.register_buffer(
"workspace",
torch.zeros(
self.out_features // 128 * self.max_par,
dtype=torch.int32,
device=self.qweight.device,
),
persistent=False,
)
@torch.no_grad()
def forward(self, x):
assert hasattr(self, "workspace"), (
"module.post_init() must be called before module.forward(). "
"Use marlin_post_init() on the whole model."
)
assert MARLIN_INSTALLED, (
"Marlin kernels are not installed. "
"Please install AWQ compatible Marlin kernels from AutoAWQ_kernels."
)
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
marlin_cuda.mul(
x,
self.qweight,
out,
self.scales,
self.workspace,
-1, # thread_k
-1, # thread_n
-1, # sms
self.max_par,
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
def marlin_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Marlin):
submodule.post_init()
return model
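
# A minimal usage sketch (not part of the commit): convert a single fp16 nn.Linear
# to WQLinear_Marlin and run a forward pass. It assumes a CUDA device and the
# marlin_cuda extension from AutoAWQ_kernels; the per-group symmetric scales are
# computed inline the same way AwqQuantizer does when zero_point=False.
if __name__ == "__main__":
    in_f, out_f, group = 4096, 4096, 128
    linear = nn.Linear(in_f, out_f, bias=False).half().cuda()

    # symmetric 4-bit scales, shape [out_features, in_features // group_size]
    w_groups = linear.weight.data.view(out_f, in_f // group, group)
    scales = w_groups.abs().amax(dim=-1) / (2 ** (4 - 1) - 1)

    q_linear = WQLinear_Marlin.from_linear(
        linear, w_bit=4, group_size=group, scales=scales, zeros=None
    )
    q_linear.post_init()  # allocates the Marlin workspace buffer used by forward()

    x = torch.randn(1, in_f, dtype=torch.float16, device="cuda")
    y = q_linear(x)  # dispatches to marlin_cuda.mul
    assert y.shape == (1, out_f)
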
@@ -4,79 +4,107 @@ import logging
import functools
import torch.nn as nn
from tqdm import tqdm
-from typing import Dict, List
+from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
+from awq.modules.linear.marlin import WQLinear_Marlin
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
    get_named_linears,
    set_op_by_name,
-    exclude_layers_to_not_quantize
+    exclude_layers_to_not_quantize,
)


class AwqQuantizer:
-    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                 calib_data, split, text_column, duo_scaling, modules_to_not_convert=None,
-                 export_compatible=False) -> None:
+    def __init__(
+        self,
+        awq_model,
+        model,
+        tokenizer,
+        w_bit,
+        group_size,
+        zero_point,
+        version,
+        calib_data,
+        split,
+        text_column,
+        duo_scaling,
+        modules_to_not_convert=None,
+        export_compatible=False,
+    ) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
+        self.zero_point = zero_point
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
        self.export_compatible = export_compatible
-        self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
+        self.modules_to_not_convert = (
+            modules_to_not_convert if modules_to_not_convert is not None else []
+        )
        self.modules, self.module_kwargs, self.inps = self.init_quant()

-    def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
+    def pseudo_quantize_tensor(self, w: torch.Tensor):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
+        assert torch.isnan(w).sum() == 0

        # zero point quantization
+        if self.zero_point:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
-            max_int = 2 ** self.w_bit - 1
+            max_int = 2**self.w_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-5) / max_int
            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+            w = (
+                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
+            ) * scales
+            zeros = zeros.view(org_w_shape[0], -1)
+        else:
+            max_val = w.abs().amax(dim=1, keepdim=True)
+            max_val = max_val.clamp(min=1e-5)
+            max_int = 2 ** (self.w_bit - 1) - 1
+            min_int = -(2 ** (self.w_bit - 1))
+            scales = max_val / max_int
+            zeros = None
+            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0
-        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
-        assert torch.isnan(w).sum() == 0
+        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)
-        if get_scale_zp:
-            return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
-        else:
-            return w
+        return w, scales, zeros

-    def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: torch.Tensor):
+    def pseudo_dequantize_tensor(
+        self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
+    ):
        # get repeated count
-        repeat_count = w.weight.data.shape[-1] // zeros.shape[-1]
-        # get zeros and scales in correct shape
-        zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
+        repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
        scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

        # dequantize
+        if self.zero_point:
+            zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
            w = (w.weight.data - zeros) * scales
+        else:
+            w = w.weight.data * scales

        return w
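
# Standalone sketch (not part of the diff): what the new zero_point=False branch
# computes. Symmetric 4-bit quantization keeps the integer grid centered on zero,
# so no zero points are produced and kernels like Marlin only need the scales.
# Tensor names below are local to this example.
import torch

w = torch.randn(64, 512, dtype=torch.float16)
group_size, w_bit = 128, 4
w_g = w.reshape(-1, group_size)                      # [n_groups_total, group_size]
max_int = 2 ** (w_bit - 1) - 1                       # 7
min_int = -(2 ** (w_bit - 1))                        # -8
scales = w_g.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / max_int
w_q = torch.clamp(torch.round(w_g / scales), min_int, max_int)   # int codes in [-8, 7]
w_dq = (w_q * scales).reshape(w.shape)               # dequantized fp16 weights, no zeros
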
@@ -94,10 +122,14 @@ class AwqQuantizer:
            common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
-                self.module_kwargs["position_ids"] = self.module_kwargs["position_ids"].to(common_device)
+                self.module_kwargs["position_ids"] = self.module_kwargs[
+                    "position_ids"
+                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
-                self.module_kwargs["attention_mask"] = self.module_kwargs["attention_mask"].to(common_device)
+                self.module_kwargs["attention_mask"] = self.module_kwargs[
+                    "attention_mask"
+                ].to(common_device)

            self.inps = self.inps.to(common_device)
@@ -105,7 +137,9 @@ class AwqQuantizer:
            named_linears = get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to exclude
-            named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
+            named_linears = exclude_layers_to_not_quantize(
+                named_linears, self.modules_to_not_convert
+            )

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()
@@ -114,14 +148,23 @@ class AwqQuantizer:
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
-            scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
+            scales_list = [
+                self._search_best_scale(self.modules[i], **layer)
+                for layer in module_config
+            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
-            scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")
+            scales_list = append_str_prefix(
+                scales_list, get_op_name(self.model, self.modules[i]) + "."
+            )

            # [STEP 3]: Compute and apply clipping list
-            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
+            clip_list = self._search_best_clip(
+                self.modules[i], named_linears, input_feat
+            )
            apply_clip(self.modules[i], clip_list)
-            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
+            clip_list = append_str_prefix(
+                clip_list, get_op_name(self.model, self.modules[i]) + "."
+            )

            # [STEP 4]: Quantize weights
            if not self.export_compatible:
@@ -132,7 +175,9 @@ class AwqQuantizer:
    def pack(self):
        for i in tqdm(range(len(self.modules)), desc="Packing"):
            named_linears = get_named_linears(self.modules[i])
-            named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
+            named_linears = exclude_layers_to_not_quantize(
+                named_linears, self.modules_to_not_convert
+            )
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()
@@ -142,25 +187,30 @@ class AwqQuantizer:
            linear_layer = linear_layer.to(get_best_device()).half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
-                linear_layer.weight.data,
-                get_scale_zp=True
+                linear_layer.weight.data
            )

-            if self.version == 'GEMM':
+            if self.version == "GEMM":
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM
-            elif self.version == 'GEMV':
+            elif self.version == "GEMV":
                q_linear_module = WQLinear_GEMV
+            elif self.version == "Marlin":
+                q_linear_module = WQLinear_Marlin
+            else:
+                raise ValueError(f"Unknown version {self.version}")

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
-                zeros=zeros
+                zeros=zeros,
            )

            linear_layer.cpu()
@@ -169,7 +219,15 @@ class AwqQuantizer:
        clear_memory()

    @torch.no_grad()
-    def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
+    def _search_best_scale(
+        self,
+        module,
+        prev_op,
+        layers: List[nn.Linear],
+        inp: torch.Tensor,
+        module2inspect=None,
+        kwargs={},
+    ):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]
@@ -202,14 +260,25 @@ class AwqQuantizer:
        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
-            inp, w_max, x_max, module2inspect,
-            layers, fp16_output, module_kwargs
+            inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
        )

-        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
+        return (
+            get_op_name(module, prev_op),
+            tuple([get_op_name(module, m) for m in layers]),
+            best_scales,
+        )

-    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
-                            fp16_output, kwargs={}):
+    def _compute_best_scale(
+        self,
+        x,
+        w_max,
+        x_max,
+        module2inspect,
+        linears2scale: List[nn.Linear],
+        fp16_output,
+        kwargs={},
+    ):
        """
        Compute loss and select best scales
@@ -223,7 +292,7 @@ class AwqQuantizer:
        history = []
        best_ratio = -1
        best_scales = None
-        best_error = float('inf')
+        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
@@ -237,7 +306,7 @@ class AwqQuantizer:
            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
-                scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
+                scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
            else:
                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
@@ -246,7 +315,9 @@ class AwqQuantizer:
            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
-                fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
+                fc.weight.data = (
+                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
+                )

            # W * X
            int_w_output = module2inspect(x, **kwargs)
@@ -254,7 +325,9 @@ class AwqQuantizer:
                int_w_output = int_w_output[0]

            # compute mean squared error (L2 norm)
-            loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
+            loss = (
+                (fp16_output - int_w_output).float().pow(2).mean().item()
+            )  # NOTE: float prevents overflow
            history.append(loss)

            if loss < best_error:
@@ -284,30 +357,36 @@ class AwqQuantizer:
            named_linears[name].to(get_best_device())
            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
            clip_list.append((name, max_val))
            named_linears[name].cpu()
        return clip_list

    @torch.no_grad()
-    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
+    def _compute_best_clip(
+        self,
+        w: torch.Tensor,
+        input_feat: torch.Tensor,
+        n_grid=20,
+        max_shrink=0.5,
+        n_sample_token=512,
+    ):
        assert w.dim() == 2
        org_w_shape = w.shape
        # w          [co, ci]      -> [co, 1, n_group, group size]
        # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
-        group_size = self.group_size if self.group_size > 0 else w.shape[1]
+        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
-        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
-        w = w.reshape(w.shape[0], 1, -1, group_size)
+        input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
+        w = w.reshape(org_w_shape[0], 1, -1, group_size)

-        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
-        assert w.shape[0] % oc_batch_size == 0
+        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
+        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

-        for i_b in range(w.shape[0] // oc_batch_size):
-            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
+        for i_b in range(org_w_shape[0] // oc_batch_size):
+            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
@@ -318,9 +397,9 @@ class AwqQuantizer:
            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
-                min_val = - max_val
+                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
-                q_w = self.pseudo_quantize_tensor(cur_w)
+                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
@@ -339,11 +418,15 @@ class AwqQuantizer:
        return best_max_val.squeeze(1)

-    def init_quant(self, n_samples=128, seqlen=512):
+    def init_quant(self, n_samples=2, seqlen=512):
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
-            data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, block_size=seqlen,
-            split=self.split, text_column=self.text_column
+            data=self.calib_data,
+            tokenizer=self.tokenizer,
+            n_samples=n_samples,
+            block_size=seqlen,
+            split=self.split,
+            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)
@@ -414,12 +497,17 @@ class AwqQuantizer:
        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
        if self.awq_model.model_type == "mixtral":
-            named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}
+            named_linears = {
+                **named_linears,
+                "block_sparse_moe": layer.block_sparse_moe,
+            }

        for name in named_linears:
-            handles.append(named_linears[name].register_forward_hook(
-                functools.partial(cache_input_hook, name=name,
-                                  feat_dict=input_feat)))
+            handles.append(
+                named_linears[name].register_forward_hook(
+                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
+                )
+            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu

        # get output as next layer's input
@@ -436,7 +524,6 @@ class AwqQuantizer:
        return input_feat

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
......
import torch

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
+from awq.modules.linear.marlin import WQLinear_Marlin
from awq.modules.linear.exllama import WQLinear_Exllama
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
@@ -58,8 +60,10 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
        q_linear = WQLinear_GEMM
    elif isinstance(q_proj, WQLinear_Exllama):
        q_linear = WQLinear_Exllama
-    else:
+    elif isinstance(q_proj, WQLinear_ExllamaV2):
        q_linear = WQLinear_ExllamaV2
+    elif isinstance(q_proj, WQLinear_Marlin):
+        q_linear = WQLinear_Marlin

    qkv_layer = q_linear(
        q_proj.w_bit,
@@ -87,6 +91,10 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+    elif isinstance(q_proj, WQLinear_Marlin):
+        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        # workspace is created in post_init
    qkv_layer.bias = bias
......
import torch
from typing import List
Q_BITS = 4
STORAGE_BITS = 32
PACK_NUM = STORAGE_BITS // Q_BITS
ORDINAL_PACK_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def pack(imatrix: torch.Tensor, direction: str = "column"):
"""
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of packing, either "column" or "row"
Returns:
qmatrix (torch.Tensor): packed matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device)
imatrix = imatrix.to(torch.int8)
imatrix = torch.bitwise_and(imatrix, 0x0F) # eventually correct overflow
if direction == "column":
imatrix = imatrix.view(-1, imatrix.shape[1] // PACK_NUM, PACK_NUM)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
elif direction == "row":
imatrix = imatrix.view(imatrix.shape[0] // PACK_NUM, PACK_NUM, -1)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
qmatrix = qmatrix.to(torch.int32)
return qmatrix
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
"""
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
Args:
qmatrix (torch.Tensor): matrix of packed integers
direction (str): direction of unpacking, either "column" or "row"
Returns:
imatrix (torch.Tensor): matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=qmatrix.device)
if direction == "column":
imatrix = torch.bitwise_right_shift(
qmatrix[:, :, None], shifts[None, None, :]
).view(qmatrix.shape[0], -1)
elif direction == "row":
imatrix = torch.bitwise_right_shift(
qmatrix[:, None, :], shifts[None, :, None]
).view(-1, qmatrix.shape[-1])
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
return imatrix
def quantize(fmatrix, scales, zeros, group_size):
"""
Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.
Args:
fmatrix (torch.Tensor): matrix of 16-bit floats
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
imatrix (torch.Tensor): matrix of 4-bit integers
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = torch.round(
(
fmatrix / scales.repeat_interleave(group_size, dim=0)
+ zeros.repeat_interleave(group_size, dim=0)
)
)
imatrix = imatrix.to(torch.int8) & 0x0F
return imatrix
def dequantize(imatrix, scales, zeros, group_size):
"""
Dequantizes a 4-bit integer matrix into a float matrix.
Args:
imatrix (torch.Tensor): matrix of 4-bit integers
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
fmatrix (torch.Tensor): matrix of 16-bit floats
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = imatrix.to(torch.int8) & 0x0F
fmatrix = (
imatrix - zeros.repeat_interleave(group_size, dim=0)
) * scales.repeat_interleave(group_size, dim=0)
fmatrix = fmatrix.to(torch.float16)
return fmatrix
def apply_order(
imatrix: torch.Tensor,
direction: str = "column",
order: List[int] = ORDINAL_PACK_ORDER,
):
"""
Applies the order to a 4-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of applying order, either "column" or "row"
order (List[int]): order to apply, default is ordinal packing order
Returns:
imatrix (torch.Tensor): matrix of integers
"""
if direction == "column":
imatrix = imatrix.view(-1, PACK_NUM)[:, order].view(imatrix.shape)
elif direction == "row":
imatrix = imatrix.view(PACK_NUM, -1)[order, :].view(imatrix.shape)
return imatrix
def awq_to_exllama(qweight, qzeros):
# awq uses column packing for both weights and zeros
izeros = unpack(qzeros, direction="column")
iweights = unpack(qweight, direction="column")
# Reverse the order of the iweight and izeros tensors
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
# Subtract 1 from the izeros tensor (exllama adds 1 during inference)
izeros = izeros - 1
# exllama uses row packing for weights and column packing for zeros
qzeros = pack(izeros, direction="column")
qweight = pack(iweights, direction="row")
return qweight, qzeros
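
# Quick self-check sketch (not part of the diff) for the helpers above: pack() and
# unpack() are inverses for a given direction, and applying AWQ_PACK_ORDER followed
# by REVERSE_AWQ_PACK_ORDER restores the original ordering. Runs on CPU with the
# definitions in this file in scope.
if __name__ == "__main__":
    imatrix = torch.randint(0, 16, (16, 32), dtype=torch.int8)  # random 4-bit codes

    qmatrix = pack(imatrix, direction="column")                 # [16, 4] int32
    assert torch.equal(unpack(qmatrix, direction="column"), imatrix)

    shuffled = apply_order(imatrix, direction="column", order=AWQ_PACK_ORDER)
    restored = apply_order(shuffled, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    assert torch.equal(restored, imatrix)
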
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"

# Load model
......
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "IlyasMoutawwakil/vicuna-7b-v1.5-awq-marlin"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
<|system|>
</s>
<|user|>
{prompt}</s>
<|assistant|>"""
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
\ No newline at end of file
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq-marlin'
quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" }
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file