Commit 724bda58 authored by Casper Hansen

Working for OPT

parent 356cbc92
@@ -15,7 +15,6 @@ from huggingface_hub import snapshot_download
 from awq.utils.utils import simple_dispatch_model
 from awq.utils.calib_data import get_calib_dataset
 from transformers.modeling_utils import shard_checkpoint
-from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
@@ -43,23 +42,32 @@ class BaseAWQForCausalLM(nn.Module):
         return self.model.generate(*args, **kwargs)

     @torch.no_grad()
-    def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
-                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
-                 calib_data: Union[str, List[str]]="pileval", split="train",
-                 text_column="text"):
+    def quantize(self, tokenizer=None, quant_config={},
+                 calib_data: Union[str, List[str]]="pileval",
+                 split="train", text_column="text"):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

-        if run_search:
-            self.search_result = self._awq_search(
-                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
-                split=split, text_column=text_column
-            )
+        from awq.quantize.quantizer import AwqQuantizer
+
+        quantizer = AwqQuantizer(
+            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
+            quant_config["version"], calib_data, split, text_column
+        )
+        quantizer.quantize()
+
+        self.is_quantized = True
+
+        # if run_search:
+        #     self.search_result = self._awq_search(
+        #         tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
+        #         auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
+        #         split=split, text_column=text_column
+        #     )

-        if run_quant:
-            self._awq_quant()
-            self.is_quantized = True
+        # if run_quant:
+        #     self._awq_quant()
+        #     self.is_quantized = True

     @staticmethod
     def fuse_layers(model, quant_config):
...
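For orientation, a minimal usage sketch of the refactored entry point (editor-added, not part of the commit): the `run_search`/`run_quant` flags are gone and `quantize()` now drives `AwqQuantizer` directly. The model name and the `AutoAWQForCausalLM` import path are assumptions.

```python
# Hypothetical usage of the new quantize() signature; paths and model are placeholders.
from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM  # assumed public wrapper around the base class shown above

model_path = "facebook/opt-1.3b"
quant_config = {"w_bit": 4, "q_group_size": 128, "version": "GEMM"}  # keys read by AwqQuantizer

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# calib_data/split/text_column keep their defaults ("pileval", "train", "text")
model.quantize(tokenizer, quant_config=quant_config)
```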
@@ -27,10 +27,12 @@ class OptAWQForCausalLM(BaseAWQForCausalLM):
         # attention input
         layers.append(dict(
             prev_op=module.self_attn_layer_norm,
-            layers=[module.self_attn.q_proj,
-                    module.self_attn.k_proj, module.self_attn.v_proj],
+            layers=[
+                module.self_attn.q_proj,
+                module.self_attn.k_proj, module.self_attn.v_proj],
             inp=input_feat['self_attn.q_proj'],
-            module2inspect=module.self_attn, kwargs=module_kwargs,
+            module2inspect=module.self_attn,
+            kwargs=module_kwargs,
         ))

         # attention out
...
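A short, self-contained sketch (editor-added; assumptions noted in the comments) of why each scaling group records a `prev_op`: the per-channel scales found for q/k/v can be folded into the preceding layer norm so the block's output is unchanged, which is what `apply_scale` relies on.

```python
# Toy illustration of folding input scales into the previous op (LayerNorm -> Linear).
import torch
import torch.nn as nn

ln = nn.LayerNorm(8)
qkv = nn.Linear(8, 24)   # stand-in for the concatenated q/k/v projections
x = torch.randn(2, 8)
ref = qkv(ln(x))

s = torch.rand(8) + 0.5  # per-input-channel scales, as found by the grid search
ln.weight.data /= s      # prev_op output is divided by s ...
ln.bias.data /= s
qkv.weight.data *= s     # ... and the following linears multiply it back

print(torch.allclose(ref, qkv(ln(x)), atol=1e-5))  # True: mathematically equivalent rescaling
```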
 import torch
 import torch.nn as nn
-from .quantizer import pseudo_quantize_tensor
 import gc

 __all__ = ["auto_clip_block"]
...
@@ -87,7 +87,51 @@ def scale_gelu_fc(gelu, fc, scales):
     for p in fc.parameters():
         assert torch.isnan(p).sum() == 0

+def pseudo_quantize_tensor(w, w_bit=4,
+                           zero_point=True,
+                           q_group_size=-1,
+                           inplace=False,
+                           get_scale_zp=False
+                           ):
+    org_w_shape = w.shape
+    if q_group_size > 0:
+        assert org_w_shape[-1] % q_group_size == 0
+        w = w.reshape(-1, q_group_size)
+    assert w.dim() == 2
+    if zero_point:
+        max_val = w.amax(dim=1, keepdim=True)
+        min_val = w.amin(dim=1, keepdim=True)
+        max_int = 2 ** w_bit - 1
+        min_int = 0
+        scales = (max_val - min_val).clamp(min=1e-5) / max_int
+        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+    else:  # we actually never used this
+        assert min_val is None
+        max_val = w.abs().amax(dim=1, keepdim=True)
+        max_val = max_val.clamp(min=1e-5)
+        max_int = 2 ** (w_bit - 1) - 1
+        min_int = - 2 ** (w_bit - 1)
+        scales = max_val / max_int
+        zeros = 0
+
+    assert torch.isnan(scales).sum() == 0
+    assert torch.isnan(w).sum() == 0
+
+    if inplace:
+        ((w.div_(scales).round_().add_(zeros)).clamp_(
+            min_int, max_int).sub_(zeros)).mul_(scales)
+    else:
+        w = (torch.clamp(torch.round(w / scales) +
+                         zeros, min_int, max_int) - zeros) * scales
+
+    assert torch.isnan(w).sum() == 0
+    w = w.reshape(org_w_shape)
+
+    if get_scale_zp:
+        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
+    else:
+        return w
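A quick sanity check (editor-added sketch, assuming this commit's tree is importable) of what the `pseudo_quantize_tensor` added above returns for 4-bit, group-size-128 quantization:

```python
# Hypothetical smoke test for the function above; shapes are arbitrary, but the
# last dim must be divisible by q_group_size.
import torch
from awq.quantize.auto_scale import pseudo_quantize_tensor  # new home per this commit

w = torch.randn(512, 512)
w_q, scales, zeros = pseudo_quantize_tensor(w, w_bit=4, q_group_size=128, get_scale_zp=True)

assert w_q.shape == w.shape               # fake-quantized weights keep the original shape
assert scales.shape == (512, 512 // 128)  # one scale per (output row, input group) pair
print((w - w_q).abs().mean())             # small but nonzero quantization error
```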
 @torch.no_grad()
 def auto_scale_block(awq_model,
@@ -95,7 +139,7 @@ def auto_scale_block(awq_model,
                      module_kwargs,
                      quant_config,
                      input_feat):
-    from .quantizer import pseudo_quantize_tensor
+    # from .quantizer import pseudo_quantize_tensor
     # firstly, get the weight quantize function
     if quant_config['w_bit'] is not None:
         def w_quantize_func(p): return pseudo_quantize_tensor(p, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"]).detach()
@@ -106,24 +150,32 @@ def auto_scale_block(awq_model,
         module_kwargs.pop("use_cache")

     # find the best scale ratio
-    def _search_module_scale(block, linears2scale: list, x, kwargs={}):
+    def _search_module_scale(module2inspect, layers: list, inp, kwargs={}):
+        if module2inspect is None:
+            assert len(layers) == 1
+            module2inspect = layers[0]
+
         # w: co, ci
         # x: n, ci
-        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
-        w_max = get_weight_scale(
-            weight, q_group_size=quant_config.get("q_group_size", -1))
+        weight = torch.cat([_m.weight for _m in layers], dim=0)
+        org_shape = weight.shape
+        weight = weight.view(-1, quant_config.get("q_group_size"))
+        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+        w_scale = w_scale.view(org_shape)
+        w_max = w_scale.mean(0)

         # Clear GPU memory
         del weight
         gc.collect()
         torch.cuda.empty_cache()

-        x = x.to(next(block.parameters()).device)
+        inp = inp.to(next(module2inspect.parameters()).device)
         with torch.no_grad():
-            org_out = block(x, **kwargs)
+            org_out = module2inspect(inp, **kwargs)
             if isinstance(org_out, tuple):
                 org_out = org_out[0]

-        x_max = get_act_scale(x)
+        x_max = get_act_scale(inp)

         best_error = float('inf')
         best_ratio = -1
@@ -132,17 +184,17 @@ def auto_scale_block(awq_model,
         n_grid = 20
         history = []

-        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
+        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
         for ratio in range(n_grid):
             ratio = ratio * 1 / n_grid
             scales = (x_max.pow(ratio) / w_max.pow(1-ratio)
                       ).clamp(min=1e-4).view(-1)
             scales = scales / (scales.max() * scales.min()).sqrt()
-            for fc in linears2scale:
+            for fc in layers:
                 fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                 fc.weight.data = w_quantize_func(
                     fc.weight.data) / (scales.view(1, -1))
-            out = block(x, **kwargs)
+            out = module2inspect(inp, **kwargs)
             if isinstance(out, tuple):
                 out = out[0]
@@ -153,7 +205,7 @@ def auto_scale_block(awq_model,
                 best_error = loss
                 best_ratio = ratio
                 best_scales = scales
-        block.load_state_dict(org_sd)
+        module2inspect.load_state_dict(org_sd)
         if best_ratio == -1:
             logging.debug(history)
             raise Exception
@@ -163,13 +215,9 @@ def auto_scale_block(awq_model,
         return best_scales.detach()

     def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
-        # module2inspect: if given, we will check the output diff of this module instead of layers
-        if module2inspect is None:
-            assert len(layers) == 1
-            module2inspect = layers[0]
         scales = _search_module_scale(module2inspect, layers, inp, kwargs)
         scales = scales.detach().cpu()
         # prev_op_name, [layer_name], scale
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
...
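For reference, the ratio search in `_search_module_scale` (and in `_compute_best_scale` below) reduces to trying `n_grid` values of alpha in `s = x_max**alpha / w_max**(1 - alpha)` and keeping the scale vector with the lowest output MSE after fake quantization. An editor's standalone sketch, with `quantize_and_eval` as a hypothetical callback:

```python
# Standalone sketch of the AWQ scale grid search (assumed helper, not library code).
import torch

def search_best_scales(x_max: torch.Tensor, w_max: torch.Tensor, quantize_and_eval, n_grid: int = 20):
    best_err, best_scales = float("inf"), None
    for i in range(n_grid):
        alpha = i / n_grid
        scales = (x_max.pow(alpha) / w_max.pow(1 - alpha)).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()  # same normalization as the diff
        err = quantize_and_eval(scales)  # MSE between original and quantized+rescaled output
        if err < best_err:
            best_err, best_scales = err, scales
    return best_scales
```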
@@ -12,7 +12,8 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,

 class AwqQuantizer:
-    def __init__(self, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
+    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
+        self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
         self.w_bit = w_bit
@@ -21,7 +22,7 @@ class AwqQuantizer:
         self.calib_data = calib_data
         self.split = split
         self.text_column = text_column
-        self.modules, self.module_kwargs = self.init_quant()
+        self.modules, self.module_kwargs, self.inps = self.init_quant()

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
         org_w_shape = w.shape
@@ -51,8 +52,8 @@ class AwqQuantizer:
         else:
             return w

-    def quantize(self, get_layers_for_scaling: function):
-        for i in tqdm(range(len(self.modules)), desc="QUANTIZING"):
+    def quantize(self):
+        for i in tqdm(range(len(self.modules)), desc="AWQ"):
             # [STEP 1]: Get layer, extract linear modules, extract input features
             self.modules[i] = self.modules[i].cuda()
             named_linears = get_named_linears(self.modules[i])
@@ -60,22 +61,22 @@ class AwqQuantizer:
             clear_memory()

             # [STEP 2]: Compute and apply scale list
-            module_config: list[dict] = get_layers_for_scaling(
+            module_config: list[dict] = self.awq_model.get_layers_for_scaling(
                 self.modules[i], input_feat, self.module_kwargs
             )
-            scales_list = [self._search_best_scale(**layer) for layer in module_config]
+            scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
             apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 3]: Compute and apply clipping list
-            clip_list = self._search_best_clip(named_linears, input_feat)
-            apply_clip(self.modules[i], clip_list)
-            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
+            # clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
+            # apply_clip(self.modules[i], clip_list)
+            # clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 4]: Quantize weights
             for name, linear_layer in named_linears.items():
                 linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
-                    linear_layer.weight.data,
+                    linear_layer.weight.data.float(),
                     get_scale_zp=True
                 )
@@ -102,107 +103,46 @@ class AwqQuantizer:
                 clear_memory()

             clear_memory()
+
+        return self.model
-    @torch.no_grad()
-    def _search_best_clip(self, layer, named_linears, input_feat):
-        clip_list = []
-        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
-
-        for name in named_linears:
-            # due to qk bmm, it is hard to clip precisely
-            if any([_ in name for _ in avoid_clipping]):
-                continue
-
-            named_linears[name].cuda()
-            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
-            clip_list.append((name, max_val))
-            named_linears[name].cpu()
-
-    @torch.no_grad()
-    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
-        assert w.dim() == 2
-        org_w_shape = w.shape
-        # w           [co, ci]      -> [co, 1, n_group, group size]
-        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
-        group_size = self.group_size if self.group_size > 0 else w.shape[1]
-        input_feat = input_feat.view(-1, input_feat.shape[-1])
-        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
-        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
-        w = w.reshape(w.shape[0], 1, -1, group_size)
-
-        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
-        assert w.shape[0] % oc_batch_size == 0
-        w_all = w
-        best_max_val_all = []
-
-        for i_b in range(w.shape[0] // oc_batch_size):
-            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
-
-            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
-
-            best_max_val = org_max_val.clone()
-            min_errs = torch.ones_like(org_max_val) * 1e9
-            input_feat = input_feat.to(w.device)
-            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
-
-            for i_s in range(int(max_shrink * n_grid)):
-                max_val = org_max_val * (1 - i_s / n_grid)
-                min_val = - max_val
-                cur_w = torch.clamp(w, min_val, max_val)
-                q_w = self.pseudo_quantize_tensor(cur_w)
-                cur_out = (input_feat * q_w).sum(dim=-1)
-
-                # co, 1, n_group, 1
-                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
-                del cur_w
-                del cur_out
-                cur_best_idx = err < min_errs
-                min_errs[cur_best_idx] = err[cur_best_idx]
-                best_max_val[cur_best_idx] = max_val[cur_best_idx]
-            best_max_val_all.append(best_max_val)
-
-        best_max_val = torch.cat(best_max_val_all, dim=0)
-
-        clear_memory(input_feat)
-        clear_memory(org_out)
-
-        return best_max_val.squeeze(1)
     @torch.no_grad()
-    def _search_best_scale(self, previous_layer, linears2scale: list[nn.Linear], x: torch.Tensor, kwargs={}):
+    def _search_best_scale(self, module, prev_op, layers: list[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
+        if module2inspect is None:
+            assert len(layers) == 1
+            module2inspect = layers[0]
+
+        if "use_cache" in kwargs:
+            kwargs.pop("use_cache")
+
         # Put x on the right device
-        x = x.to(next(previous_layer.parameters()).device)
+        inp = inp.to(next(module2inspect.parameters()).device)

         # [STEP 1]: Compute maximum of weight
-        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
+        weight = torch.cat([_m.weight for _m in layers], dim=0)
+        org_shape = weight.shape
         weight = weight.view(-1, self.group_size)
-        w_max = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
-        w_max = w_max.view(weight.shape)
-        w_max = w_max.mean(0)
+        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+        w_scale = w_scale.view(org_shape)
+        w_max = w_scale.mean(0)
         clear_memory(weight)

         # [STEP 2]: Compute maximum of x
-        x_max = x.abs().view(-1, x.shape[-1]).mean(0)
+        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

-        # [STEP 3]: Compute output of previous layer
+        # [STEP 3]: Compute output of module
         with torch.no_grad():
-            org_out = previous_layer(x, **kwargs)
+            org_out = module2inspect(inp, **kwargs)
             if isinstance(org_out, tuple):
                 org_out = org_out[0]

         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
-            x, w_max, x_max, previous_layer,
-            linears2scale, org_out, kwargs
+            inp, w_max, x_max, module2inspect,
+            layers, org_out, kwargs
         )

-        return best_scales
+        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)

-    def _compute_best_scale(self, x, w_max, x_max, previous_layer, linears2scale: list[nn.Linear], org_out, kwargs={}):
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear], org_out, kwargs={}):
         """
         Compute loss and select best scales
@@ -218,7 +158,7 @@ class AwqQuantizer:
         best_scales = None
         best_error = float('inf')
-        org_sd = {k: v.cpu() for k, v in previous_layer.state_dict().items()}
+        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
         device = x.device
         x_max = x_max.view(-1).to(device)
@@ -235,7 +175,7 @@ class AwqQuantizer:
                 fc.weight.mul_(scales_view)
                 fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view

-            out = previous_layer(x, **kwargs)
+            out = module2inspect(x, **kwargs)
             if isinstance(out, tuple):
                 out = out[0]
@@ -246,7 +186,7 @@ class AwqQuantizer:
                 best_error = loss
                 best_ratio = ratio
                 best_scales = scales.clone()

-        previous_layer.load_state_dict(org_sd)
+        module2inspect.load_state_dict(org_sd)

         if best_ratio == -1:
             logging.debug(history)
@@ -254,12 +194,76 @@ class AwqQuantizer:
         assert torch.isnan(best_scales).sum() == 0, best_scales

-        return best_scales.detach()
+        return best_scales.detach().cpu()
+    @torch.no_grad()
+    def _search_best_clip(self, layer, named_linears, input_feat):
+        clip_list = []
+        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
+
+        for name in named_linears:
+            # due to qk bmm, it is hard to clip precisely
+            if any([_ in name for _ in avoid_clipping]):
+                continue
+
+            named_linears[name].cuda()
+            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
+            clip_list.append((name, max_val))
+            named_linears[name].cpu()
+
+    @torch.no_grad()
+    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
+        assert w.dim() == 2
+        org_w_shape = w.shape
+        # w           [co, ci]      -> [co, 1, n_group, group size]
+        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
+        group_size = self.group_size if self.group_size > 0 else w.shape[1]
+        input_feat = input_feat.view(-1, input_feat.shape[-1])
+        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
+        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
+        w = w.reshape(w.shape[0], 1, -1, group_size)
+
+        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
+        assert w.shape[0] % oc_batch_size == 0
+        w_all = w
+        best_max_val_all = []
+
+        for i_b in range(w.shape[0] // oc_batch_size):
+            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
+
+            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
+
+            best_max_val = org_max_val.clone()
+            min_errs = torch.ones_like(org_max_val) * 1e9
+            input_feat = input_feat.to(w.device)
+            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
+
+            for i_s in range(int(max_shrink * n_grid)):
+                max_val = org_max_val * (1 - i_s / n_grid)
+                min_val = - max_val
+                cur_w = torch.clamp(w, min_val, max_val)
+                q_w = self.pseudo_quantize_tensor(cur_w)
+                cur_out = (input_feat * q_w).sum(dim=-1)
+
+                # co, 1, n_group, 1
+                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
+                del cur_w
+                del cur_out
+                cur_best_idx = err < min_errs
+                min_errs[cur_best_idx] = err[cur_best_idx]
+                best_max_val[cur_best_idx] = max_val[cur_best_idx]
+            best_max_val_all.append(best_max_val)
+
+        best_max_val = torch.cat(best_max_val_all, dim=0)
+
+        clear_memory(input_feat)
+        clear_memory(org_out)
+
+        return best_max_val.squeeze(1)
+
-    def init_quant(self, n_samples=128, seqlen=512):
-        layers = self.get_model_layers(self.model)
+    def init_quant(self, n_samples=128, seqlen=512):
+        modules = self.awq_model.get_model_layers(self.model)
         samples = get_calib_dataset(
             data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, block_size=seqlen,
             split=self.split, text_column=self.text_column
@@ -269,8 +273,8 @@ class AwqQuantizer:
         inps = []
         layer_kwargs = {}

-        layers[0] = layers[0].cuda()
-        self.move_embed(self.model, "cuda")
+        modules[0] = modules[0].cuda()
+        self.awq_model.move_embed(self.model, "cuda")

         # get input and kwargs to layer 0
         # with_kwargs is only supported in PyTorch 2.0
@@ -286,21 +290,21 @@ class AwqQuantizer:
                 raise ValueError  # early exit to break later inference

         # patch layer 0 to catch input and kwargs
-        layers[0] = Catcher(layers[0])
+        modules[0] = Catcher(modules[0])
         try:
             self.model(samples.to(next(self.model.parameters()).device))
         except ValueError:  # work with early exit
             pass
         del samples

-        layers[0] = layers[0].module  # restore
+        modules[0] = modules[0].module  # restore
         inps = inps[0]

-        layers[0] = layers[0].cpu()
-        self.move_embed(self.model, "cpu")
+        modules[0] = modules[0].cpu()
+        self.awq_model.move_embed(self.model, "cpu")

         clear_memory()

-        return layers, layer_kwargs
+        return modules, layer_kwargs, inps
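The `Catcher` used by `init_quant` is only partially visible in this hunk; a minimal sketch of the pattern (editor-added; the constructor arguments are assumptions): wrap the first block, record its input and kwargs, and raise `ValueError` so the rest of the forward pass never runs.

```python
# Illustrative input-capturing wrapper; model/layer names are placeholders.
import torch
import torch.nn as nn

class Catcher(nn.Module):
    def __init__(self, module, inps: list, layer_kwargs: dict):
        super().__init__()
        self.module = module
        self.inps = inps
        self.layer_kwargs = layer_kwargs

    def forward(self, hidden_states, **kwargs):
        self.inps.append(hidden_states)   # cache the first block's input
        self.layer_kwargs.update(kwargs)  # cache attention_mask, position ids, etc.
        raise ValueError                  # early exit to break later inference

# Sketch of how it is swapped in and out around one calibration forward pass:
# inps, layer_kwargs = [], {}
# modules[0] = Catcher(modules[0], inps, layer_kwargs)
# try:
#     model(samples)
# except ValueError:
#     pass
# modules[0] = modules[0].module  # restore
```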
     def _get_input_feat(self, layer, named_linears):
         # firstly, get input features of all linear layers
@@ -315,9 +319,9 @@ class AwqQuantizer:
             handles.append(named_linears[name].register_forward_hook(
                 functools.partial(cache_input_hook, name=name,
                                   feat_dict=input_feat)))
-        inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
+        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
         # get output as next layer's input
-        inps = layer(inps, **self.module_kwargs)[0]
+        self.inps = layer(self.inps, **self.module_kwargs)[0]
         for h in handles:
             h.remove()

         # now solve for scaling and clipping
...
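Finally, `_get_input_feat` depends on forward hooks bound with `functools.partial`; a self-contained sketch (editor-added; the hook body is an assumption) of how per-linear inputs get cached during a single pass:

```python
# Minimal demo of caching each linear layer's input by name via forward hooks.
import functools
import torch
import torch.nn as nn

def cache_input_hook(module, inputs, output, name, feat_dict):
    feat_dict.setdefault(name, []).append(inputs[0].detach())  # first positional arg = activation

layer = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
named_linears = {n: m for n, m in layer.named_modules() if isinstance(m, nn.Linear)}

input_feat, handles = {}, []
for name, linear in named_linears.items():
    handles.append(linear.register_forward_hook(
        functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))

_ = layer(torch.randn(4, 16))  # one forward pass populates input_feat
for h in handles:
    h.remove()

print({k: v[0].shape for k, v in input_feat.items()})  # inputs of Linear '0' and Linear '2'
```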