Commit 35ac58c7 authored by Casper Hansen

Cleanup code

parent e09dc751
@@ -18,7 +18,7 @@ class BaseAWQForCausalLM:
     @torch.no_grad()
     def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=False, run_quant=True,
-                 calib_data="pileval", init_only=False):
+                 calib_data="pileval"):
         search_result = None
         if run_search:
@@ -26,12 +26,12 @@ class BaseAWQForCausalLM:
                 auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

         if run_quant:
-            self._awq_quant(model, w_bit, q_config, init_only)
+            self._awq_quant(model, w_bit, q_config)

         return search_result

-    def _awq_quant(self, model, w_bit, q_config, init_only):
+    def _awq_quant(self, model, w_bit, q_config):
         assert q_config["zero_point"], "We only support zero_point quantization now."

         layers = self.get_model_layers(model)
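With init_only gone, quantize() now only toggles the search and quantization passes. A minimal caller sketch; MptAWQForCausalLM and the already loaded model/tokenizer are assumptions made for illustration, not part of this diff:

    # Hypothetical usage after this change
    q_config = {"zero_point": True, "q_group_size": 128}

    awq = MptAWQForCausalLM()            # assumed concrete subclass of BaseAWQForCausalLM
    search_result = awq.quantize(
        model, tokenizer, w_bit=4, q_config=q_config,
        run_search=True,                 # collect AWQ scales on the pileval calibration set
        run_quant=True,                  # then pack the weights into WQLinear modules
    )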
@@ -39,36 +39,20 @@ class BaseAWQForCausalLM:
         for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
             layer = layers[i]
             named_linears = get_named_linears(layer)
-
-            if not isinstance(layer.ffn.act, ScaledActivation):
-                param = next(layer.parameters())
-
-                # get activation scale
-                scale_dict = self.get_act_for_scaling(layer)
-                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
-
-                # scale activation
-                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
-                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
+            self._scale_activations(layer)

             for name, module in named_linears.items():
-                if init_only:
-                    q_linear = WQLinear.from_linear(
-                        module, w_bit, q_config['q_group_size'], True)
-                    q_linear.to(next(layer.parameters()).device)
-                    set_op_by_name(layer, name, q_linear)
-                else:
-                    module.cuda()
-                    module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
-                    scales = scales.t().contiguous()
-                    zeros = zeros.t().contiguous()
-                    q_linear = WQLinear.from_linear(
-                        module, w_bit, q_config['q_group_size'], False, scales, zeros)
-                    module.cpu()
-                    q_linear.to(next(layer.parameters()).device)
-                    set_op_by_name(layer, name, q_linear)
-                    torch.cuda.empty_cache()
-                    gc.collect()
+                module.cuda()
+                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
+                scales = scales.t().contiguous()
+                zeros = zeros.t().contiguous()
+                q_linear = WQLinear.from_linear(
+                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
+                module.cpu()
+                q_linear.to(next(layer.parameters()).device)
+                set_op_by_name(layer, name, q_linear)
+                torch.cuda.empty_cache()
+                gc.collect()

             torch.cuda.empty_cache()
             gc.collect()
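The MPT-specific isinstance(layer.ffn.act, ...) block is folded into a single self._scale_activations(layer) call. The helper itself is not part of this diff; the sketch below reconstructs what it presumably contains from the removed inline lines, with the generalized scale_dict['scale_layer'] check being an assumption:

    def _scale_activations(self, layer):
        # Reconstruction of the removed inline block; not shown in this commit.
        scale_dict = self.get_act_for_scaling(layer)

        if not isinstance(scale_dict['scale_layer'], ScaledActivation):
            param = next(layer.parameters())

            # identity scale with the right shape/dtype/device
            scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

            # wrap the activation so AWQ scales can later be folded into it
            scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)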
Second file in this commit (path not captured):

@@ ... @@
 import torch
-from tqdm import tqdm
-EMBEDDING_KEYWORDS = ["embed"]
-LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]

 # core quantization method (simulated quantization)
 def pseudo_quantize_tensor(w, n_bit=8,
@@ -47,24 +44,3 @@ def pseudo_quantize_tensor(w, n_bit=8,
         return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
     else:
         return w
-
-@torch.no_grad()
-def pseudo_quantize_model_weight(
-    model, w_bit, q_config,
-):
-    from .pre_quant import get_blocks, get_named_linears
-
-    layers = get_blocks(model)
-    for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
-        named_linears = get_named_linears(layers[i])
-        for n, m in named_linears.items():
-            m.cuda()
-            m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
-            m.cpu()
-
-@torch.no_grad()
-def real_quantize_model_weight(model, awq_model):
-    layers = awq_model.get_model_layers(model)
-    for i in tqdm(range(len(layers)), desc="real weight quantization..."):
-        layer = layers[i]
-        del layer
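Both removed helpers were thin per-layer loops around pseudo_quantize_tensor and are superseded by _awq_quant above. For reference, the simulated zero-point group quantization that pseudo_quantize_tensor performs looks roughly like the standalone sketch below; the function name and shape handling are illustrative, not the library code:

    import torch

    def groupwise_zero_point_quant(w: torch.Tensor, n_bit: int = 4, group_size: int = 128):
        # Minimal sketch of simulated ("pseudo") zero-point quantization per weight group.
        # Assumes w.numel() is divisible by group_size; illustration only.
        org_shape = w.shape
        w = w.reshape(-1, group_size)                   # [n_groups, group_size]

        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1

        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-min_val / scales).round()

        # quantize to integers, clamp to the n_bit range, then dequantize back to float
        w_q = ((w / scales).round() + zeros).clamp(0, max_int)
        w_dq = (w_q - zeros) * scales

        return w_dq.reshape(org_shape), scales, zeros

With get_scale_zp=True the real function also returns scales and zeros reshaped to [out_features, n_groups] (see the retained return statement above), which _awq_quant transposes before handing them to WQLinear.from_linear.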