Commit d68441d0 authored by Casper Hansen

Load quantized model with saved quant_config

parent 84d23089
@@ -46,12 +46,12 @@ def run_quant(model_path, search_path, dump_path, quant_config):
     # Save quantized model
     model.save_quantized(dump_path)

-def run_perplexity(quant_path, quant_file, quant_config, device):
+def run_perplexity(quant_path, quant_file, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
-    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, quant_config)
+    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
     tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

     # Load adapter
@@ -92,6 +92,6 @@ if __name__ == '__main__':
     elif args.entry_type == 'quant':
         run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
     elif args.entry_type == 'perplexity':
-        run_perplexity(args.quant_path, args.quant_file, args.w_bit, quant_config, args.device)
+        run_perplexity(args.quant_path, args.quant_file, args.device)
     else:
         raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
@@ -16,8 +16,6 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
     return model_type

 class AutoAWQForCausalLM:
-    default_quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
-
     def __init__(self):
         raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                                'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
@@ -31,11 +29,10 @@ class AutoAWQForCausalLM:
         )

     @classmethod
-    def from_quantized(self, quant_path, quant_filename, quant_config={},
+    def from_quantized(self, quant_path, quant_filename,
                        device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
-        quant_config = quant_config if quant_config else self.default_quant_config

         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, quant_config, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
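With quant_config dropped from the public signature, loading a quantized checkpoint only needs the path, the weight file, and optionally a device. A minimal usage sketch; the package import path and file names below are assumptions, not part of this diff:

from awq import AutoAWQForCausalLM          # import path assumed
from transformers import AutoTokenizer

quant_path = "opt-125m-awq"                 # hypothetical quantized checkpoint directory
quant_file = "awq_model_w4_g128.pt"         # hypothetical weight file name

# quant_config is now read from quant_config.json inside quant_path, falling back
# to the default {"zero_point": True, "q_group_size": 128, "w_bit": 4}
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, device="balanced")
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)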
@@ -3,7 +3,6 @@ import gc
 import json
 import torch
 import functools
-import accelerate
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
@@ -44,11 +43,11 @@ class BaseAWQForCausalLM(nn.Module):
             auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

         if run_quant:
-            self._awq_quant(quant_config)
+            self._awq_quant()

-    def _awq_quant(self, quant_config):
-        assert quant_config["zero_point"], "We only support zero_point quantization now."
+    def _awq_quant(self):
+        assert self.quant_config["zero_point"], "We only support zero_point quantization now."

         layers = self.get_model_layers(self.model)

         # Run AWQ quantization
@@ -59,11 +58,25 @@
             for name, module in named_linears.items():
                 module.cuda()
-                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, w_bit=quant_config['w_bit'], get_scale_zp=True, **quant_config)
+                module.weight.data, scales, zeros = pseudo_quantize_tensor(
+                    module.weight.data,
+                    get_scale_zp=True,
+                    **self.quant_config
+                )
+
                 scales = scales.t().contiguous()
                 zeros = zeros.t().contiguous()
+
                 q_linear = WQLinear.from_linear(
-                    module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros)
+                    module,
+                    self.quant_config['w_bit'],
+                    self.quant_config['q_group_size'],
+                    False,
+                    scales,
+                    zeros
+                )
+
                 module.cpu()
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
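For context, pseudo_quantize_tensor performs group-wise zero-point (asymmetric) quantization of each weight matrix, returning the fake-quantized weights plus the per-group scales and zero points consumed by WQLinear.from_linear above. The sketch below is an illustrative reimplementation of that idea, not the library's code; it assumes the weight's element count is divisible by q_group_size:

import torch

def groupwise_zero_point_quantize(w: torch.Tensor, w_bit: int = 4, q_group_size: int = 128):
    """Illustrative group-wise asymmetric quantization (scales + zero points)."""
    orig_shape = w.shape
    w = w.reshape(-1, q_group_size)                      # one row per quantization group
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** w_bit - 1
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp(0, max_int)
    # Fake-quantize: snap to the integer grid, then map back to float
    w_q = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    w_dq = (w_q - zeros) * scales
    return w_dq.reshape(orig_shape), scales, zeros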
@@ -228,7 +241,7 @@
         )

     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, quant_config={},
+    def from_quantized(self, model_path, model_type, model_filename,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
         # Download model if path is not a directory
@@ -245,6 +258,14 @@
             model_filename = model_path + f'/{model_filename}'

         # Load config
+        quant_config_path = f'{model_path}/quant_config.json'
+        if os.path.exists(quant_config_path):
+            with open(quant_config_path, 'r') as file:
+                quant_config = json.loads(file.read())
+        else:
+            # Default config that works for most models
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
+
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

         # Load empty weights
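The point of the commit is that the quantization settings now travel with the checkpoint as quant_config.json. A rough sketch of the save/load round trip; how save_quantized writes the file is assumed here, only the loading fallback above appears in the diff:

import os, json

quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
dump_path = "opt-125m-awq"                   # hypothetical output directory

# Presumably written by save_quantized alongside the quantized weights
os.makedirs(dump_path, exist_ok=True)
with open(os.path.join(dump_path, "quant_config.json"), "w") as f:
    json.dump(quant_config, f, indent=2)

# Mirrors the loading logic added to from_quantized above
config_path = os.path.join(dump_path, "quant_config.json")
if os.path.exists(config_path):
    with open(config_path, "r") as f:
        quant_config = json.load(f)
else:
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}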
@@ -254,7 +275,7 @@
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, w_bit, quant_config)
+            self._load_quantized_modules(self, model, quant_config)

         model.tie_weights()
@@ -281,7 +302,7 @@
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

-    def _load_quantized_modules(self, model, w_bit, quant_config):
+    def _load_quantized_modules(self, model, quant_config):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."
@@ -300,7 +321,7 @@
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
                 q_linear = WQLinear.from_linear(
-                    module, w_bit, quant_config['q_group_size'], True)
+                    module, quant_config['w_bit'], quant_config['q_group_size'], True)
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
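set_op_by_name swaps a submodule identified by its dotted name for the freshly built WQLinear. A minimal illustrative equivalent, not the library's implementation:

import torch.nn as nn

def set_module_by_name(root: nn.Module, name: str, new_module: nn.Module):
    """Replace root.<name> (dotted path, e.g. 'self_attn.q_proj') with new_module."""
    parts = name.split(".")
    parent = root
    for p in parts[:-1]:
        parent = getattr(parent, p)
    setattr(parent, parts[-1], new_module)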