"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "6fd04ca922e5da7ef8c52d86118fc58b798a7e4a"
Commit d3550fec authored by Casper Hansen

Remove old quantization code

parent 724bda58
@@ -2,25 +2,18 @@ import os
 import gc
 import json
 import torch
-import logging
-import functools
 import torch.nn as nn
 from tqdm import tqdm
 from typing import List, Union
-from collections import defaultdict
 from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.utils import simple_dispatch_model
-from awq.utils.calib_data import get_calib_dataset
 from transformers.modeling_utils import shard_checkpoint
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.quantize.auto_clip import auto_clip_block, apply_clip
-from awq.quantize.auto_scale import auto_scale_block, apply_scale
+from awq.utils.module import get_named_linears, set_op_by_name
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
 from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
-from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

 class BaseAWQForCausalLM(nn.Module):
     def __init__(self, model, model_type, is_quantized, quant_config):
@@ -55,183 +48,11 @@ class BaseAWQForCausalLM(nn.Module):
             quant_config["version"], calib_data, split, text_column
         )
         quantizer.quantize()

         self.is_quantized = True
-        # if run_search:
-        #     self.search_result = self._awq_search(
-        #         tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-        #         auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
-        #         split=split, text_column=text_column
-        #     )
-        # if run_quant:
-        #     self._awq_quant()
-        #     self.is_quantized = True

     @staticmethod
     def fuse_layers(model, quant_config):
         pass
-    def _awq_quant(self):
-        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
-        layers = self.get_model_layers(self.model)
-        # Run AWQ quantization
-        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
-            layer = layers[i]
-            named_linears = get_named_linears(layer)
-            self._scale_activations(self, layer)
-            for name, module in named_linears.items():
-                module.cuda()
-                module.weight.data, scales, zeros = pseudo_quantize_tensor(
-                    module.weight.data,
-                    get_scale_zp=True,
-                    w_bit=self.quant_config["w_bit"],
-                    q_group_size=self.quant_config["q_group_size"]
-                )
-                if self.quant_config["version"] == 'GEMM':
-                    scales = scales.t().contiguous()
-                    zeros = zeros.t().contiguous()
-                    q_linear_module = WQLinear_GEMM
-                elif self.quant_config["version"] == 'GEMV':
-                    q_linear_module = WQLinear_GEMV
-                q_linear = q_linear_module.from_linear(
-                    module,
-                    self.quant_config['w_bit'],
-                    self.quant_config['q_group_size'],
-                    False,
-                    scales,
-                    zeros
-                )
-                module.cpu()
-                q_linear.to(next(layer.parameters()).device)
-                set_op_by_name(layer, name, q_linear)
-                torch.cuda.empty_cache()
-                gc.collect()
-        torch.cuda.empty_cache()
-        gc.collect()
-    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
-                    auto_scale=True, mse_range=True, calib_data: Union[str, List[str]] = "pileval",
-                    split="train", text_column="text"):
-        layers = self.get_model_layers(self.model)
-        samples = get_calib_dataset(
-            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
-            split=split, text_column=text_column
-        )
-        samples = torch.cat(samples, dim=0)
-        inps = []
-        layer_kwargs = {}
-        layers[0] = layers[0].cuda()
-        self.move_embed(self.model, "cuda")
-        # get input and kwargs to layer 0
-        # with_kwargs is only supported in PyTorch 2.0
-        # use this Catcher hack for now
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-            def forward(self, hijacked_inputs, **kwargs):
-                inps.append(hijacked_inputs)
-                layer_kwargs.update(kwargs)
-                raise ValueError  # early exit to break later inference
-        # patch layer 0 to catch input and kwargs
-        layers[0] = Catcher(layers[0])
-        try:
-            self.model(samples.to(next(self.model.parameters()).device))
-        except ValueError:  # work with early exit
-            pass
-        del samples
-        layers[0] = layers[0].module  # restore
-        inps = inps[0]
-        layers[0] = layers[0].cpu()
-        self.move_embed(self.model, "cpu")
-        gc.collect()
-        torch.cuda.empty_cache()
-        awq_results = {
-            "scale": [],
-            "clip": [],
-        }
-        # Run AWQ search layer by layer
-        for i in tqdm(range(len(layers)), desc="AWQ Search"):
-            layer = layers[i]
-            layer = layer.cuda()
-            named_linears = get_named_linears(layer)
-            # firstly, get input features of all linear layers
-            def cache_input_hook(m, x, y, name, feat_dict):
-                x = x[0]
-                x = x.detach().cpu()
-                feat_dict[name].append(x)
-            input_feat = defaultdict(list)
-            handles = []
-            for name in named_linears:
-                handles.append(named_linears[name].register_forward_hook(
-                    functools.partial(cache_input_hook, name=name,
-                                      feat_dict=input_feat)))
-            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
-            # get output as next layer's input
-            inps = layer(inps, **layer_kwargs)[0]
-            for h in handles:
-                h.remove()
-            # now solve for scaling and clipping
-            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
-            # Clear GPU memory
-            torch.cuda.empty_cache()
-            if auto_scale:  # if it applies, we should also modify the input_feat with scales
-                scales_list = auto_scale_block(
-                    self,
-                    layer,
-                    layer_kwargs,
-                    quant_config=quant_config,
-                    input_feat=input_feat,
-                )
-                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
-                # append prefix to make names global
-                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")
-            # Clear GPU memory
-            torch.cuda.empty_cache()
-            if mse_range:
-                clip_list = auto_clip_block(
-                    layer,
-                    quant_config=quant_config,
-                    input_feat=input_feat
-                )
-                apply_clip(layer, clip_list)
-                # append prefix to make names global
-                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")
-            layer = layer.cpu()
-            # Haotian: check activation replacement
-            del input_feat
-            gc.collect()
-            torch.cuda.empty_cache()
-        return awq_results
     def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
         def _save_files(save_dir, model_name='', search_result=None):
...
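One detail of the removed _awq_search worth keeping in mind while reading the diff: it captured the inputs to the first decoder layer by wrapping that layer in a Catcher module whose forward records the hidden states and kwargs and then raises to abort the rest of the forward pass. A minimal, self-contained sketch of that pattern is shown below; the toy nn.Sequential model, its sizes, and the variable names are illustrative only, not part of the codebase.

import torch
import torch.nn as nn

inps, layer_kwargs = [], {}

class Catcher(nn.Module):
    """Wraps a block, records its inputs and kwargs, then aborts the forward pass."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, hidden_states, **kwargs):
        inps.append(hidden_states)      # closure over the outer list, as in the removed code
        layer_kwargs.update(kwargs)
        raise ValueError                # early exit: only the inputs to this block are needed

# Toy "model": an embedding followed by two blocks standing in for decoder layers.
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.Linear(16, 16))
model[1] = Catcher(model[1])            # patch the first block

try:
    model(torch.randint(0, 100, (2, 8)))
except ValueError:                      # the Catcher fired; everything after it is skipped
    pass

model[1] = model[1].module              # restore the original block
print(inps[0].shape)                    # torch.Size([2, 8, 16]) -- captured layer-0 input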
awq/quantize/auto_clip.py (deleted):

import torch
import torch.nn as nn
import gc

__all__ = ["auto_clip_block"]


# weight quantization
@torch.no_grad()
def auto_clip_layer(w,
                    input_feat,
                    quant_config,
                    n_grid=20,
                    max_shrink=0.5,
                    n_sample_token=512):
    assert w.dim() == 2
    org_w_shape = w.shape
    # w [co, ci] -> [co, 1, n_group, group size]
    # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
    group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
    input_feat = input_feat.view(-1, input_feat.shape[-1])
    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
    input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
    w = w.reshape(w.shape[0], 1, -1, group_size)

    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []

    for i_b in range(w.shape[0] // oc_batch_size):
        w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
        org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
        best_max_val = org_max_val.clone()
        min_errs = torch.ones_like(org_max_val) * 1e9
        input_feat = input_feat.to(w.device)
        org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

        for i_s in range(int(max_shrink * n_grid)):
            max_val = org_max_val * (1 - i_s / n_grid)
            min_val = -max_val
            cur_w = torch.clamp(w, min_val, max_val)
            q_w = pseudo_quantize_tensor(cur_w, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"])
            cur_out = (input_feat * q_w).sum(dim=-1)
            # co, 1, n_group, 1
            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
            del cur_w
            del cur_out
            cur_best_idx = err < min_errs
            min_errs[cur_best_idx] = err[cur_best_idx]
            best_max_val[cur_best_idx] = max_val[cur_best_idx]
        best_max_val_all.append(best_max_val)

    best_max_val = torch.cat(best_max_val_all, dim=0)

    del input_feat
    del org_out
    gc.collect()
    torch.cuda.empty_cache()
    return best_max_val.squeeze(1)


@torch.no_grad()
def auto_clip_block(module,
                    quant_config,
                    input_feat):
    named_linears = {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

    clip_list = []
    for name in named_linears:
        # due to qk bmm, it is hard to clip precisely
        if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
            continue
        named_linears[name].cuda()
        max_val = auto_clip_layer(
            named_linears[name].weight, input_feat[name], quant_config=quant_config)
        clip_list.append((name, max_val))
        named_linears[name].cpu()
    return clip_list


@torch.no_grad()
def apply_clip(module, clip_list):
    from ..utils.module import get_op_by_name
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()
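For orientation, a small hypothetical example of what apply_clip does with the per-group thresholds returned by auto_clip_layer: the weight is viewed as [out_channels, n_groups, group_size], clamped to plus/minus max_val, and reshaped back. The 4x8 weight and the group size of 4 below are made up for illustration.

import torch

w = torch.randn(4, 8)                  # toy weight: 4 output channels, 8 input channels
group_size = 4

# Pretend the search returned thresholds at 80% of each group's absolute maximum,
# shaped [out_channels, n_groups, 1] like best_max_val.squeeze(1) above.
max_val = w.abs().reshape(4, -1, group_size).amax(dim=-1, keepdim=True) * 0.8

org_shape = w.shape
w_clipped = w.reshape(*max_val.shape[:2], -1)          # [4, 2, 4]
w_clipped = torch.clamp(w_clipped, -max_val, max_val)  # clamp each group to its threshold
w_clipped = w_clipped.reshape(org_shape)

print((w_clipped.abs() < w.abs()).any().item())        # True: the group extremes were shrunk

auto_clip_layer finds that threshold by shrinking the range over n_grid steps and keeping whichever clamp gives the lowest output MSE against the cached calibration activations.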
awq/quantize/auto_scale.py (deleted):

import gc
import torch
import torch.nn as nn
import logging

from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

__all__ = ["auto_scale_block", "apply_scale"]


@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    org_shape = weight.shape
    if q_group_size > 0:
        weight = weight.view(-1, q_group_size)
    scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
    scale = scale.view(org_shape)
    scale = scale.mean(0)
    return scale


@torch.no_grad()
def get_act_scale(x):
    return x.abs().view(-1, x.shape[-1]).mean(0)


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    # debugging start even scales = 1 does not work?
    """
    scales = scales * 0
    scales = scales + 1
    """
    # debugging end

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
    # assert fc1.out_features == fc2.in_features

    scales = scales.to(fc1.weight.device)

    # fc1.weight.div_(scales.view(-1, 1))
    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    assert any(isinstance(gelu, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0


def pseudo_quantize_tensor(w, w_bit=4,
                           zero_point=True,
                           q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (w_bit - 1) - 1
        min_int = -2 ** (w_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w


@torch.no_grad()
def auto_scale_block(awq_model,
                     module,
                     module_kwargs,
                     quant_config,
                     input_feat):
    # from .quantizer import pseudo_quantize_tensor
    # firstly, get the weight quantize function
    if quant_config['w_bit'] is not None:
        def w_quantize_func(p):
            return pseudo_quantize_tensor(p, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"]).detach()
    else:
        def w_quantize_func(p):
            return p

    if "use_cache" in module_kwargs:
        module_kwargs.pop("use_cache")

    # find the best scale ratio
    def _search_module_scale(module2inspect, layers: list, inp, kwargs={}):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        # w: co, ci
        # x: n, ci
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        weight = weight.view(-1, quant_config.get("q_group_size"))
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)

        # Clear GPU memory
        del weight
        gc.collect()
        torch.cuda.empty_cache()

        inp = inp.to(next(module2inspect.parameters()).device)
        with torch.no_grad():
            org_out = module2inspect(inp, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        x_max = get_act_scale(inp)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
        for ratio in range(n_grid):
            ratio = ratio * 1 / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)
                      ).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in layers:
                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                fc.weight.data = w_quantize_func(
                    fc.weight.data) / (scales.view(1, -1))
            out = module2inspect(inp, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            module2inspect.load_state_dict(org_sd)
        if best_ratio == -1:
            logging.debug(history)
            raise Exception
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()
        # prev_op_name, [layer_name], scale
        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)

    layers: list[dict] = awq_model.get_layers_for_scaling(
        module, input_feat, module_kwargs
    )
    scales_list = [_auto_get_scale(**layer) for layer in layers]

    return scales_list


def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif any(isinstance(prev_op, t) for t in [nn.LayerNorm, LlamaRMSNorm]) \
                or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif any(isinstance(prev_op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales.cpu()
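The core of _search_module_scale above is a grid over ratio in [0, 1): per-channel scales x_max**ratio / w_max**(1 - ratio) are applied to the weights, the weights are pseudo-quantized and un-scaled, and the ratio with the lowest output MSE wins. Below is a compressed, self-contained sketch of that loop on a single toy nn.Linear; the inlined 4-bit group quantizer is a simplified stand-in for pseudo_quantize_tensor, and all sizes and names are illustrative.

import torch
import torch.nn as nn

torch.manual_seed(0)

def quantize_groupwise(w, w_bit=4, group_size=8):
    # Simplified zero-point group quantization, mirroring pseudo_quantize_tensor above.
    g = w.reshape(-1, group_size)
    max_val, min_val = g.amax(dim=1, keepdim=True), g.amin(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-5) / (2 ** w_bit - 1)
    zeros = (-torch.round(min_val / scales)).clamp(0, 2 ** w_bit - 1)
    q = (torch.clamp(torch.round(g / scales) + zeros, 0, 2 ** w_bit - 1) - zeros) * scales
    return q.reshape(w.shape)

fc = nn.Linear(16, 16, bias=False)
x = torch.randn(64, 16) * torch.linspace(0.1, 4.0, 16)   # skewed per-channel activation magnitudes
org_out = fc(x)

x_max = x.abs().mean(0)                                   # per-channel activation scale
w_norm = fc.weight.abs().view(-1, 8)
w_norm = w_norm / w_norm.amax(dim=1, keepdim=True)        # group-normalised weight magnitudes
w_max = w_norm.view(fc.weight.shape).mean(0)

best_loss, best_ratio = float("inf"), -1.0
for i in range(20):                                       # n_grid = 20, as in the removed code
    ratio = i / 20
    s = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
    s = s / (s.max() * s.min()).sqrt()
    w_q = quantize_groupwise(fc.weight * s.view(1, -1)) / s.view(1, -1)
    loss = (org_out - x @ w_q.t()).pow(2).mean().item()
    if loss < best_loss:
        best_loss, best_ratio = loss, ratio

print(f"best ratio = {best_ratio:.2f}, mse = {best_loss:.6f}")

In the removed pipeline the winning scales are then folded into the preceding operation by apply_scale, so the activation rescaling adds no cost at inference time.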
@@ -69,9 +69,9 @@ class AwqQuantizer:
             scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 3]: Compute and apply clipping list
-            # clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
-            # apply_clip(self.modules[i], clip_list)
-            # clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
+            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
+            apply_clip(self.modules[i], clip_list)
+            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 4]: Quantize weights
             for name, linear_layer in named_linears.items():
@@ -211,6 +211,8 @@ class AwqQuantizer:
             clip_list.append((name, max_val))
             named_linears[name].cpu()
+        return clip_list

     @torch.no_grad()
     def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
...