Commit d68441d0 authored by Casper Hansen

Load quantized model with saved quant_config

parent 84d23089
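This commit persists the quantization config next to the weights: `from_quantized` now reads `quant_config.json` from the checkpoint directory (falling back to defaults) instead of requiring callers to re-supply it. A minimal before/after sketch of the public API, with hypothetical paths and an assumed `awq` import path:

```python
from awq import AutoAWQForCausalLM  # import path assumed, not shown in this diff

quant_path = "vicuna-7b-awq"         # hypothetical checkpoint directory
quant_file = "awq_model_w4_g128.pt"  # hypothetical weights file

# Before: the caller had to pass the same quant_config used at quantization time
# model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, quant_config)

# After: the config is read from <quant_path>/quant_config.json automatically
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
```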
@@ -46,12 +46,12 @@ def run_quant(model_path, search_path, dump_path, quant_config):
     # Save quantized model
     model.save_quantized(dump_path)
 
-def run_perplexity(quant_path, quant_file, quant_config, device):
+def run_perplexity(quant_path, quant_file, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
-    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, quant_config)
+    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
     tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 
     # Load adapter
@@ -92,6 +92,6 @@ if __name__ == '__main__':
     elif args.entry_type == 'quant':
         run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
     elif args.entry_type == 'perplexity':
-        run_perplexity(args.quant_path, args.quant_file, args.w_bit, quant_config, args.device)
+        run_perplexity(args.quant_path, args.quant_file, args.device)
     else:
         raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
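Note that the old `perplexity` entry point was broken independently of this refactor: `run_perplexity` declared four parameters while the call site passed five positional arguments, so that branch would have raised before doing any work. A minimal reproduction of the old mismatch (values hypothetical):

```python
def run_perplexity(quant_path, quant_file, quant_config, device):
    pass  # dummy body; only the call shape matters here

# Old call site shape: five positional arguments into a four-parameter function
run_perplexity("vicuna-7b-awq", "awq_model_w4_g128.pt", 4, {"w_bit": 4}, "cuda:0")
# TypeError: run_perplexity() takes 4 positional arguments but 5 were given
```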
@@ -16,8 +16,6 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
     return model_type
 
 class AutoAWQForCausalLM:
-    default_quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
-
     def __init__(self):
         raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                                'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
@@ -31,11 +29,10 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, quant_config={},
+    def from_quantized(self, quant_path, quant_filename,
                        device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
-        quant_config = quant_config if quant_config else self.default_quant_config
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, quant_config, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
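With the class-level default gone, the auto wrapper is a thin dispatcher: it resolves the model type and forwards to the matching implementation, and all quant_config handling now lives in `BaseAWQForCausalLM.from_quantized` below. A simplified sketch of that dispatch pattern (stand-in class and single-entry map; the real map covers every supported architecture):

```python
class LlamaAWQForCausalLM:  # stand-in for a model-specific implementation
    @classmethod
    def from_quantized(cls, quant_path, model_type, quant_filename,
                       device="balanced", trust_remote_code=True):
        print(f"loading {model_type} weights {quant_filename} from {quant_path}")

AWQ_CAUSAL_LM_MODEL_MAP = {"llama": LlamaAWQForCausalLM}

model_type = "llama"  # in the library, check_and_get_model_type() derives this
AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
    "vicuna-7b-awq", model_type, "awq_model_w4_g128.pt"
)
```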
@@ -3,7 +3,6 @@ import gc
 import json
 import torch
 import functools
-import accelerate
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
@@ -44,11 +43,11 @@ class BaseAWQForCausalLM(nn.Module):
                           auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
 
         if run_quant:
-            self._awq_quant(quant_config)
+            self._awq_quant()
 
-    def _awq_quant(self, quant_config):
-        assert quant_config["zero_point"], "We only support zero_point quantization now."
+    def _awq_quant(self):
+        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
         layers = self.get_model_layers(self.model)
 
         # Run AWQ quantization
@@ -59,11 +58,25 @@ class BaseAWQForCausalLM(nn.Module):
             for name, module in named_linears.items():
                 module.cuda()
-                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, w_bit=quant_config['w_bit'], get_scale_zp=True, **quant_config)
+                module.weight.data, scales, zeros = pseudo_quantize_tensor(
+                    module.weight.data,
+                    get_scale_zp=True,
+                    **self.quant_config
+                )
                 scales = scales.t().contiguous()
                 zeros = zeros.t().contiguous()
 
                 q_linear = WQLinear.from_linear(
-                    module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros)
+                    module,
+                    self.quant_config['w_bit'],
+                    self.quant_config['q_group_size'],
+                    False,
+                    scales,
+                    zeros
+                )
 
                 module.cpu()
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
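The removed one-liner also had a latent bug: it passed `w_bit` explicitly *and* expanded `**quant_config`, which itself contains a `w_bit` key, and Python rejects duplicate keyword arguments at the call site. The new form expands `**self.quant_config` once. A demonstration of that duplicate-keyword error (dummy function; the real `pseudo_quantize_tensor` signature is not shown in this diff):

```python
def pseudo_quantize_tensor(w, w_bit=4, get_scale_zp=False, **kwargs):
    return w  # dummy body; only the call semantics matter here

cfg = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

# Old call shape: w_bit supplied both explicitly and again via **cfg
pseudo_quantize_tensor(None, w_bit=cfg["w_bit"], get_scale_zp=True, **cfg)
# TypeError: pseudo_quantize_tensor() got multiple values for keyword argument 'w_bit'
```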
@@ -228,7 +241,7 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, quant_config={},
+    def from_quantized(self, model_path, model_type, model_filename,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
         # Download model if path is not a directory
@@ -245,6 +258,14 @@ class BaseAWQForCausalLM(nn.Module):
             model_filename = model_path + f'/{model_filename}'
 
         # Load config
+        quant_config_path = f'{model_path}/quant_config.json'
+        if os.path.exists(quant_config_path):
+            with open(quant_config_path, 'r') as file:
+                quant_config = json.loads(file.read())
+        else:
+            # Default config that works for most models
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
+
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
 
         # Load empty weights
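This hunk is the core of the commit: a config saved alongside the checkpoint wins, and the fallback reproduces the defaults that were removed from the auto class above, so older checkpoints without a `quant_config.json` keep loading. A standalone sketch of the same lookup against a hypothetical checkpoint directory (the JSON file is presumably written by `save_quantized`):

```python
import json
import os

model_path = "vicuna-7b-awq"  # hypothetical checkpoint directory
quant_config_path = f"{model_path}/quant_config.json"

if os.path.exists(quant_config_path):
    # e.g. file contents: {"zero_point": true, "q_group_size": 128, "w_bit": 4}
    with open(quant_config_path) as f:
        quant_config = json.load(f)
else:
    # Same fallback as the diff: defaults that work for most models
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

print(quant_config["w_bit"])  # 4
```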
@@ -254,7 +275,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, w_bit, quant_config)
+            self._load_quantized_modules(self, model, quant_config)
 
             model.tie_weights()
@@ -281,7 +302,7 @@ class BaseAWQForCausalLM(nn.Module):
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
-    def _load_quantized_modules(self, model, w_bit, quant_config):
+    def _load_quantized_modules(self, model, quant_config):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."
@@ -300,7 +321,7 @@ class BaseAWQForCausalLM(nn.Module):
                 # Replace nn.Linear with WQLinear
                 for name, module in named_linears.items():
                     q_linear = WQLinear.from_linear(
-                        module, w_bit, quant_config['q_group_size'], True)
+                        module, quant_config['w_bit'], quant_config['q_group_size'], True)
                     q_linear.to(next(layer.parameters()).device)
                     set_op_by_name(layer, name, q_linear)