Commit d35ade75 authored by Casper Hansen

Rename q_config -> quant_config. Include w_bit in quant_config. Save quant_config.json.

parent 6f30f051
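For reference, a minimal sketch of the consolidated configuration used throughout the diff below (keys and defaults taken from the updated entry script and default_quant_config; only zero-point quantization is supported):

quant_config = {
    "zero_point": True,    # zero-point (asymmetric) quantization
    "q_group_size": 128,   # group size for per-group scales
    "w_bit": 4,            # weight bit-width, previously a separate w_bit argument
}

# The single dict now replaces every (w_bit, q_config) pair, e.g.:
# model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)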
@@ -15,7 +15,7 @@ def load_search_result_into_memory(model, search_path):
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
def run_search(model_path, dump_path, w_bit, q_config):
def run_search(model_path, dump_path, quant_config):
"""
Step 1/2: Search the pile for an optimal scaling factor.
"""
@@ -24,7 +24,7 @@ def run_search(model_path, dump_path, w_bit, q_config):
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)
# Save search results
model.save_quantized(dump_path)
@@ -32,7 +32,7 @@ def run_search(model_path, dump_path, w_bit, q_config):
# Save tokenizer
tokenizer.save_pretrained(dump_path)
def run_quant(model_path, search_path, dump_path, w_bit, q_config):
def run_quant(model_path, search_path, dump_path, quant_config):
"""
Step 2/2: Use the search results to quantize model weights
"""
@@ -41,17 +41,17 @@ def run_quant(model_path, search_path, dump_path, w_bit, q_config):
load_search_result_into_memory(model.model, search_path)
# Run actual weight quantization
model.quantize(w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
model.quantize(quant_config=quant_config, run_search=False, run_quant=True)
# Save quantized model
model.save_quantized(dump_path)
def run_perplexity(quant_path, quant_file, w_bit, q_config, device):
def run_perplexity(quant_path, quant_file, quant_config, device):
"""
Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
"""
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit, q_config)
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, quant_config)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
# Load adapter
@@ -85,13 +85,13 @@ if __name__ == '__main__':
parser.add_argument('--q_group_size', type=int, default=128)
args = parser.parse_args()
q_config = { "zero_point": True, "q_group_size": args.q_group_size }
quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
if args.entry_type == 'search':
run_search(args.model_path, args.search_path, args.w_bit, q_config)
run_search(args.model_path, args.search_path, quant_config)
elif args.entry_type == 'quant':
run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
elif args.entry_type == 'perplexity':
run_perplexity(args.quant_path, args.quant_file, args.w_bit, q_config, args.device)
run_perplexity(args.quant_path, args.quant_file, quant_config, args.device)
else:
raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
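A hedged usage sketch of the refactored entry points above; the model path and output directories are placeholders, not values from this commit:

quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

# Step 1/2: search for scaling/clipping factors and save them
run_search("path/to/model", "search_out/", quant_config)

# Step 2/2: load the search results and quantize the weights
run_quant("path/to/model", "search_out/", "quant_out/", quant_config)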
@@ -16,7 +16,7 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
return model_type
class AutoAWQForCausalLM:
default_q_config = {"zero_point": True, "q_group_size": 128}
default_quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
def __init__(self):
raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
@@ -31,11 +31,11 @@ class AutoAWQForCausalLM:
)
@classmethod
def from_quantized(self, quant_path, quant_filename, w_bit=4, q_config={},
def from_quantized(self, quant_path, quant_filename, quant_config={},
device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(quant_path, trust_remote_code)
q_config = q_config if q_config else self.default_q_config
quant_config = quant_config if quant_config else self.default_quant_config
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, model_type, quant_filename, w_bit, q_config, device, trust_remote_code=trust_remote_code
quant_path, model_type, quant_filename, quant_config, device, trust_remote_code=trust_remote_code
)
\ No newline at end of file
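A short sketch of loading with the updated AutoAWQForCausalLM.from_quantized signature; the directory and checkpoint file name are placeholders. Passing an empty dict falls back to default_quant_config as shown above:

model = AutoAWQForCausalLM.from_quantized(
    "path/to/quantized-model",   # placeholder directory
    "awq_model.pt",              # placeholder quantized checkpoint file
    quant_config={},             # empty -> {"zero_point": True, "q_group_size": 128, "w_bit": 4}
)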
import os
import gc
import json
import torch
import functools
import accelerate
@@ -18,11 +19,12 @@ from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_a
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
class BaseAWQForCausalLM:
def __init__(self, model, model_type, is_quantized):
def __init__(self, model, model_type, is_quantized, quant_config):
self.model:PreTrainedModel = model
self.model_type:str = model_type
self.is_quantized:bool = is_quantized
self.search_result = None
self.quant_config:dict = quant_config
def to(self, device: str):
return self.model.to(device)
@@ -31,20 +33,21 @@ class BaseAWQForCausalLM:
return self.model(*args, **kwargs)
@torch.no_grad()
def quantize(self, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=False, run_quant=True,
calib_data="pileval"):
self.quant_config = quant_config
if run_search:
self.search_result = self._awq_search(tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
if run_quant:
self._awq_quant(w_bit, q_config)
self._awq_quant(quant_config)
def _awq_quant(self, w_bit, q_config):
assert q_config["zero_point"], "We only support zero_point quantization now."
def _awq_quant(self, quant_config):
assert quant_config["zero_point"], "We only support zero_point quantization now."
layers = self.get_model_layers(self.model)
# Run AWQ quantization
@@ -55,11 +58,11 @@ class BaseAWQForCausalLM:
for name, module in named_linears.items():
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, get_scale_zp=True, **quant_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
@@ -69,7 +72,7 @@ class BaseAWQForCausalLM:
torch.cuda.empty_cache()
gc.collect()
def _awq_search(self, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
layers = self.get_model_layers(self.model)
@@ -148,12 +151,14 @@ class BaseAWQForCausalLM:
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
self,
layer, layer_kwargs,
w_bit=w_bit, q_config=q_config,
layer,
layer_kwargs,
quant_config=quant_config,
input_feat=input_feat,
)
# apply_scale(layer, scales_list, input_feat_dict=input_feat)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")
@@ -161,9 +166,12 @@ class BaseAWQForCausalLM:
torch.cuda.empty_cache()
if mse_range:
clip_list = auto_clip_block(layer,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,)
clip_list = auto_clip_block(
layer,
quant_config=quant_config,
input_feat=input_feat
)
apply_clip(layer, clip_list)
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")
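The search results assembled above use the same two keys that load_search_result_into_memory replays at the top of this commit; a rough sketch of that round trip (the container structure is inferred from the apply_scale/apply_clip calls shown, not spelled out in this hunk):

awq_results = {"scale": [], "clip": []}   # filled per layer; names made global via append_str_prefix
# ...later, after loading the saved search results:
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])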
@@ -191,6 +199,10 @@ class BaseAWQForCausalLM:
# Save search results
torch.save(model, f'{save_dir}/{model_name}')
# Save config
with open(f'{save_dir}/quant_config.json', 'w+') as file:
file.write(json.dumps(self.quant_config, indent=4))
save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
# Save model
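The hunk above only writes quant_config.json; nothing in this commit reads it back yet, but a consumer could reload it roughly like this (paths and file names are placeholders):

import json

with open("path/to/quantized-model/quant_config.json") as f:   # file written by save_quantized above
    quant_config = json.load(f)
model = AutoAWQForCausalLM.from_quantized("path/to/quantized-model", "awq_model.pt", quant_config)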
@@ -215,7 +227,7 @@ class BaseAWQForCausalLM:
)
@classmethod
def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
def from_quantized(self, model_path, model_type, model_filename, quant_config={},
device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
safetensors=False, is_quantized=True):
# Download model if path is not a directory
@@ -241,7 +253,7 @@ class BaseAWQForCausalLM:
# Only need to replace layers if a model is AWQ quantized
if is_quantized:
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, w_bit, q_config)
self._load_quantized_modules(self, model, quant_config)
model.tie_weights()
@@ -266,11 +278,11 @@ class BaseAWQForCausalLM:
)
model.eval()
return self(model, model_type, is_quantized=is_quantized)
return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
def _load_quantized_modules(self, model, w_bit, q_config):
def _load_quantized_modules(self, model, quant_config):
# Real quantization of weights
assert q_config["zero_point"], "We only support zero_point quantization now."
assert quant_config["zero_point"], "We only support zero_point quantization now."
# Get blocks of model
layers = self.get_model_layers(model)
@@ -287,7 +299,7 @@ class BaseAWQForCausalLM:
# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
module, quant_config['w_bit'], quant_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
......
@@ -8,7 +8,9 @@ __all__ = ["auto_clip_block"]
# weight quantization
@torch.no_grad()
def auto_clip_layer(w, input_feat, n_bit, q_config,
def auto_clip_layer(w,
input_feat,
quant_config,
n_grid=20,
max_shrink=0.5,
n_sample_token=512):
@@ -16,7 +18,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = q_config["q_group_size"] if q_config["q_group_size"] > 0 else w.shape[1]
group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
@@ -41,7 +43,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = pseudo_quantize_tensor(cur_w, n_bit=n_bit, **q_config)
q_w = pseudo_quantize_tensor(cur_w, **quant_config)
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
@@ -64,7 +66,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
@torch.no_grad()
def auto_clip_block(module,
w_bit, q_config,
quant_config,
input_feat):
named_linears = {name: m for name,
@@ -77,7 +79,7 @@ def auto_clip_block(module,
continue
named_linears[name].cuda()
max_val = auto_clip_layer(
named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
named_linears[name].weight, input_feat[name], quant_config=quant_config)
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
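Note that auto_clip_layer now forwards the whole dict with **quant_config; that works only because the keys line up with the renamed pseudo_quantize_tensor parameters (see the quantizer diff further below). A minimal illustration:

quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
# equivalent to: pseudo_quantize_tensor(cur_w, zero_point=True, q_group_size=128, w_bit=4)
q_w = pseudo_quantize_tensor(cur_w, **quant_config)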
......
@@ -90,15 +90,14 @@ def scale_gelu_fc(gelu, fc, scales):
@torch.no_grad()
def auto_scale_block(awq_model,
module, module_kwargs,
w_bit, q_config,
module,
module_kwargs,
quant_config,
input_feat):
from .quantizer import pseudo_quantize_tensor
# firstly, get the weight quantize function
if w_bit is not None:
def w_quantize_func(p): return pseudo_quantize_tensor(
p, n_bit=w_bit, **q_config,
).detach()
if quant_config['w_bit'] is not None:
def w_quantize_func(p): return pseudo_quantize_tensor(p, **quant_config).detach()
else:
def w_quantize_func(p): return p
@@ -111,7 +110,7 @@ def auto_scale_block(awq_model,
# x: n, ci
weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
w_max = get_weight_scale(
weight, q_group_size=q_config.get("q_group_size", -1))
weight, q_group_size=quant_config.get("q_group_size", -1))
# Clear GPU memory
del weight
gc.collect()
......
import torch
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
zero_point=True, q_group_size=-1,
def pseudo_quantize_tensor(w, w_bit=4,
zero_point=True,
q_group_size=-1,
inplace=False,
get_scale_zp=False
):
@@ -14,7 +15,7 @@ def pseudo_quantize_tensor(w, n_bit=8,
if zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2 ** n_bit - 1
max_int = 2 ** w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
@@ -22,8 +23,8 @@ def pseudo_quantize_tensor(w, n_bit=8,
assert min_val is None
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (n_bit - 1) - 1
min_int = - 2 ** (n_bit - 1)
max_int = 2 ** (w_bit - 1) - 1
min_int = - 2 ** (w_bit - 1)
scales = max_val / max_int
zeros = 0
......
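A small worked example of the zero-point branch above with w_bit=4 (illustrative numbers; the scales clamp of 1e-5 is omitted for clarity):

w_bit = 4
max_int = 2 ** w_bit - 1              # 15
min_int = 0
max_val, min_val = 1.0, -0.5          # per-group max/min of the weights
scales = (max_val - min_val) / max_int    # 0.1
zeros = -round(min_val / scales)          # 5, then clamped to [min_int, max_int]
# the elided remainder of the function rounds w / scales, adds zeros,
# clamps to [min_int, max_int], and dequantizes as (q - zeros) * scales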