Commit 45c22ee5 authored by Casper

Initial quantization refactoring

parent a5e8b048
import torch
import torch.nn as nn
from typing import Tuple
from awq.modules.act import ScaledActivation
from transformers.activations import NewGELUActivation
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation]
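
# The helpers below implement the two weight transformations used in AWQ-style
# quantization: apply_scale folds per-channel scales into the op preceding a set of
# Linear layers (a Linear, a LayerNorm/RMSNorm, or a GELU-style activation) and into
# the Linear weights themselves, while apply_clip clamps Linear weights to the
# per-group maxima found during clip search.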

@torch.no_grad()
def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
    for name, max_val in clip_list:
        layer: nn.Linear = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()
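
# apply_scale expects scales_list entries of the form (prev_op_name, layer_names, scales):
# the previous op's outputs are divided by `scales` and the listed Linear layers' input
# channels are multiplied by `scales`, leaving the block's output unchanged (up to
# rounding) while making the Linear weights easier to quantize. If input_feat_dict is
# given, the cached activations are rescaled the same way so the later clip search sees
# the scaled inputs.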
def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif any(isinstance(prev_op, t) for t in allowed_norms) \
                or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif any(isinstance(prev_op, t) for t in allowed_act_fns):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales.cpu()

@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0

@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)

    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0

@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    assert any(isinstance(gelu, t) for t in allowed_act_fns)
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.apply_quantized import apply_scale, apply_clip
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

class AwqQuantizer:
    def __init__(self, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        # init_quant also returns the calibration inputs to the first module, which
        # _get_input_feat propagates from one module to the next
        self.modules, self.module_kwargs, self.inps = self.init_quant()
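
    # Illustrative walk-through of the zero-point scheme used by pseudo_quantize_tensor
    # below (numbers are made up): with w_bit=4, a group with min=-0.6 and max=0.9 gives
    # max_int=15, scales=(0.9 - (-0.6))/15=0.1 and zeros=round(0.6/0.1)=6, so a weight of
    # 0.23 is stored as round(0.23/0.1)+6=8 and dequantizes to (8-6)*0.1=0.2.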

    # core quantization method (simulated quantization)
    def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2

        # zero point quantization
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** self.w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(w).sum() == 0

        w = w.reshape(org_w_shape)

        if get_scale_zp:
            return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
        else:
            return w

    def quantize(self, get_layers_for_scaling):
        for i in tqdm(range(len(self.modules)), desc=""):
            # [STEP 1]: Get layer, extract linear modules, extract input features
            self.modules[i] = self.modules[i].cuda()
            named_linears = get_named_linears(self.modules[i])
            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: list[dict] = get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [self._search_best_scale(**layer) for layer in module_config]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 4]: Quantize weights
            for name, linear_layer in named_linears.items():
                linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                    linear_layer.weight.data,
                    get_scale_zp=True
                )

                if self.version == 'GEMM':
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear_module = WQLinear_GEMM
                elif self.version == 'GEMV':
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    linear=linear_layer,
                    w_bit=self.w_bit,
                    group_size=self.group_size,
                    init_only=False,
                    scales=scales,
                    zeros=zeros
                )

                linear_layer.cpu()
                q_linear.to(next(self.modules[i].parameters()).device)
                set_op_by_name(self.modules[i], name, q_linear)
                clear_memory()

            clear_memory()

        return self.model
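
    # quantize() expects `get_layers_for_scaling` to return, for each module, a list of
    # dicts whose keys match the parameters of _search_best_scale below. A hypothetical
    # Llama-style entry (module/feature names are illustrative, not part of this commit):
    #   dict(previous_layer=module.input_layernorm,
    #        linears2scale=[module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj],
    #        x=input_feat['self_attn.q_proj'],
    #        kwargs=module_kwargs)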

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].cuda()
            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list
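
    # _compute_best_clip below grid-searches a per-group clipping threshold for each
    # weight: the original |w| maximum is shrunk in n_grid steps (down to roughly
    # 1 - max_shrink of its value) and the threshold whose clipped-then-quantized output
    # gives the lowest MSE on the sampled calibration tokens is kept.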

    @torch.no_grad()
    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else w.shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
        w = w.reshape(w.shape[0], 1, -1, group_size)

        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
        assert w.shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(w.shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = - max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]

            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    @torch.no_grad()
    def _search_best_scale(self, previous_layer, linears2scale: list[nn.Linear], x: torch.Tensor, kwargs={}):
        # Put x on the right device
        x = x.to(next(previous_layer.parameters()).device)

        # [STEP 1]: Compute maximum of weight
        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
        org_shape = weight.shape
        weight = weight.view(-1, self.group_size)
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        # reshape back so the mean is taken per input channel, not per group position
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute maximum of x
        x_max = x.abs().view(-1, x.shape[-1]).mean(0)

        # [STEP 3]: Compute output of previous layer
        with torch.no_grad():
            org_out = previous_layer(x, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            x, w_max, x_max, previous_layer,
            linears2scale, org_out, kwargs
        )

        return best_scales

    def _compute_best_scale(self, x, w_max, x_max, previous_layer, linears2scale: list[nn.Linear], org_out, kwargs={}):
        """
        Compute loss and select best scales

        L(s) = || Q(W \cdot s) (s^{-1} \cdot X) - W \cdot X ||

        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []

        best_ratio = -1
        best_scales = None
        best_error = float('inf')

        org_sd = {k: v.cpu() for k, v in previous_layer.state_dict().items()}

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio * 1 / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()

            # multiply scale and quantize
            for fc in linears2scale:
                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / (scales.view(1, -1))

            out = previous_layer(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            # measure loss and check if better than best
            loss = (org_out - out).float().pow(2).mean().item()  # NOTE: float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales

            # restore original weights before trying the next ratio
            previous_layer.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception

        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def init_quant(self, n_samples=128, seqlen=512):
        layers = self.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, block_size=seqlen,
            split=self.split, text_column=self.text_column
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples

        layers[0] = layers[0].module  # restore
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")
        clear_memory()

        return layers, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []
        for name in named_linears:
            handles.append(named_linears[name].register_forward_hook(
                functools.partial(cache_input_hook, name=name,
                                  feat_dict=input_feat)))
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu

        # get output as next layer's input
        self.inps = layer(self.inps, **self.module_kwargs)[0]
        for h in handles:
            h.remove()

        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat
import gc
import torch
import accelerate
...
@@ -53,3 +54,9 @@ def set_module_name(model, name, value):
        child_name = name

    setattr(parent, child_name, value)

def clear_memory(weight=None):
    if weight is not None:
        del weight
    gc.collect()
    torch.cuda.empty_cache()