Commit ed347704 authored by Casper Hansen

Implemented save_quantized. Generalized from_quantized. Added comments.

parent 14d198c6
@@ -16,29 +16,45 @@ def load_search_result_into_memory(model, search_path):
apply_clip(model, awq_results["clip"])
def run_search(model_path, dump_path, w_bit, q_config):
"""
Step 1/2: Run the AWQ search on the Pile calibration data to find optimal scaling and clipping factors.
"""
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
awq_results = model.quantize(model.model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
dirpath = os.path.dirname(dump_path)
os.makedirs(dirpath, exist_ok=True)
torch.save(awq_results, dump_path)
# Quantize
model.quantize(tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
# Save search results
model.save_quantized(dump_path)
def run_quant(model_path, search_path, dump_path, w_bit, q_config):
"""
Step 2/2: Use the search results to quantize model weights
"""
# Load model and search results
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
load_search_result_into_memory(model.model, search_path)
model.quantize(model.model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
dirpath = os.path.dirname(dump_path)
os.makedirs(dirpath, exist_ok=True)
torch.save(model.model.cpu().state_dict(), dump_path)
# Run actual weight quantization
model.quantize(w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
# Save quantized model
model.save_quantized(dump_path)
def run_perplexity(model_path, quant_path, w_bit, q_config, device):
"""
Post-quantization: Evaluate perplexity on the wikitext task with the EleutherAI LM Evaluation Harness.
"""
# Load model
model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Load adapter
lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
# Evaluate perplexity of quantized model
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=['wikitext'],
@@ -50,19 +66,21 @@ def run_perplexity(model_path, quant_path, w_bit, q_config, device):
print(evaluator.make_table(results))
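# make_table() prints whatever metrics the harness reports for the 'wikitext' task
# (perplexity-style scores; the exact metric names depend on the installed lm-eval version).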
if __name__ == '__main__':
"""
python -m awq.entry --entry_type search --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq
python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/pytorch_model.bin --quant_path mpt-7b-8k-chat-awq
python -m awq.entry --entry_type perplexity --model_path mosaicml/mpt-7b-8k-chat --quant_path mpt-7b-8k-chat-awq
"""
parser = argparse.ArgumentParser()
parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
parser.add_argument('--model_path', type=str, help='Path to hf model')
parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
parser.add_argument('--device', type=str, default='balanced', help='Device to load model to')
parser.add_argument('--w_bit', type=int, default=4)
parser.add_argument('--q_group_size', type=int, default=128)
args = parser.parse_args()
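# NOTE: the hard-coded paths below override whatever --model_path, --search_path and
# --quant_path were passed on the command line.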
args.model_path = "./mpt-7b-8k-chat"
args.search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
args.quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
q_config = { "zero_point": True, "q_group_size": args.q_group_size }
if args.entry_type == 'search':
......
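A minimal Python sketch of the same two-step flow using the entry-script functions above (the model id and output paths are illustrative, not taken from this commit):

    q_config = {"zero_point": True, "q_group_size": 128}

    # Step 1: AWQ search; per the CLI example above, the search result is saved as <dump_path>/pytorch_model.bin
    run_search("mosaicml/mpt-7b-8k-chat", "./mpt-awq-search", w_bit=4, q_config=q_config)

    # Step 2: apply the saved search result, quantize the weights, save the quantized model
    run_quant("mosaicml/mpt-7b-8k-chat", "./mpt-awq-search/pytorch_model.bin",
              "./mpt-awq-quant", w_bit=4, q_config=q_config)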
from transformers import AutoConfig
from awq.models import MptAWQForCausalLM
from awq.models.base import BaseAWQForCausalLM
AWQ_CAUSAL_LM_MODEL_MAP = {
"mpt": MptAWQForCausalLM,
@@ -13,12 +14,14 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
return model_type
class AutoAWQForCausalLM:
default_q_config = {"zero_point": True, "q_group_size": 128}
def __init__(self):
raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
@classmethod
def from_pretrained(self, model_path, trust_remote_code=True):
def from_pretrained(self, model_path, trust_remote_code=True) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
@@ -26,9 +29,11 @@ class AutoAWQForCausalLM:
)
@classmethod
def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
def from_quantized(self, model_path, quant_file, w_bit=4, q_config={},
device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code)
q_config = q_config if q_config else self.default_q_config
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
model_path, quant_path, w_bit, q_config, device, trust_remote_code
model_path, model_type, quant_file, w_bit, q_config, device, trust_remote_code=trust_remote_code
)
\ No newline at end of file
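A minimal usage sketch of the two entry points above (the import path, hub id and weights file name are assumptions for illustration, not part of this commit):

    from awq.models.auto import AutoAWQForCausalLM  # import path assumed

    # Full-precision model, wrapped for AWQ search/quantization
    model = AutoAWQForCausalLM.from_pretrained("mosaicml/mpt-7b-8k-chat")

    # Quantized model: rebuilds the module tree with WQLinear layers and loads the given weights file;
    # q_config falls back to default_q_config when left empty
    model = AutoAWQForCausalLM.from_quantized(
        "mosaicml/mpt-7b-8k-chat",       # resolved via snapshot_download
        "awq_model_w4_g128.pt",          # illustrative file name inside that repo
        w_bit=4,
        device="balanced",               # forwarded as the accelerate device_map
    )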
import os
import gc
import torch
import functools
@@ -5,8 +6,9 @@ import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from transformers import AutoModelForCausalLM, AutoConfig
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
from awq.quantize.auto_clip import auto_clip_block, apply_clip
@@ -16,29 +18,27 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
class BaseAWQForCausalLM:
def __init__(self, model, model_type, is_quantized):
self.model = model
self.model_type = model_type
self.is_quantized = is_quantized
self.model: PreTrainedModel = model
self.model_type: str = model_type
self.is_quantized: bool = is_quantized
self.search_result = None
@torch.no_grad()
def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
def quantize(self, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=False, run_quant=True,
calib_data="pileval"):
search_result = None
if run_search:
search_result = self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
self.search_result = self._awq_search(tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
if run_quant:
self._awq_quant(model, w_bit, q_config)
return search_result
self._awq_quant(w_bit, q_config)
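# Note: quantize() no longer returns the search result; it is stored on self.search_result
# so save_quantized() can persist either the search statistics or the quantized state dict later.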
def _awq_quant(self, model, w_bit, q_config):
def _awq_quant(self, w_bit, q_config):
assert q_config["zero_point"], "We only support zero_point quantization now."
layers = self.get_model_layers(model)
layers = self.get_model_layers(self.model)
# Run AWQ quantization
for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
@@ -62,9 +62,9 @@ class BaseAWQForCausalLM:
torch.cuda.empty_cache()
gc.collect()
def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
def _awq_search(self, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
layers = self.get_model_layers(model)
layers = self.get_model_layers(self.model)
samples = get_calib_dataset(
data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
@@ -74,7 +74,7 @@ class BaseAWQForCausalLM:
layer_kwargs = {}
layers[0] = layers[0].cuda()
self.move_embed(model, "cuda")
self.move_embed(self.model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
@@ -92,7 +92,7 @@ class BaseAWQForCausalLM:
# patch layer 0 to catch input and kwargs
layers[0] = Catcher(layers[0])
try:
model(samples.to(next(model.parameters()).device))
self.model(samples.to(next(self.model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
@@ -100,7 +100,7 @@ class BaseAWQForCausalLM:
inps = inps[0]
layers[0] = layers[0].cpu()
self.move_embed(model, "cpu")
self.move_embed(self.model, "cpu")
gc.collect()
torch.cuda.empty_cache()
@@ -148,7 +148,7 @@ class BaseAWQForCausalLM:
# apply_scale(layer, scales_list, input_feat_dict=input_feat)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")
# Clear GPU memory
torch.cuda.empty_cache()
@@ -159,7 +159,7 @@ class BaseAWQForCausalLM:
input_feat=input_feat,)
apply_clip(layer, clip_list)
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")
layer = layer.cpu()
# Haotian: check activation replacement
@@ -169,39 +169,77 @@ class BaseAWQForCausalLM:
return awq_results
def save_quantized():
pass
def save_quantized(self, save_dir):
save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
# Save model
if self.search_result is None:
self.model.save_pretrained(save_dir, state_dict=self.model.state_dict())
else:
self.model.save_pretrained(save_dir, state_dict=self.search_result)
# TODO: Rename model name & save quant_config
if self.search_result is not None:
model_name = 'awq_model_search_result.pt'
else:
model_name = 'awq_model_w4_g128.pt'
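# Behaviour note: with a search result present, only the AWQ scale/clip statistics are written;
# otherwise the (quantized) state dict is saved via save_pretrained(). model_name is computed
# but not used yet -- see the TODO above.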
@classmethod
def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
trust_remote_code=True):
return self.from_quantized(
model_path,
model_type,
quant_file='',
device='balanced',
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
is_quantized=False
)
@classmethod
def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, trust_remote_code=True):
def from_quantized(self, model_path, model_type, quant_file, w_bit=4, q_config={},
device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
# Download model
model_path = snapshot_download(model_path)
quant_path = model_path + f'/{quant_file}' if is_quantized else model_path
# Load config
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
# Load empty weights
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
# Only need to replace layers if the model is AWQ-quantized
if is_quantized:
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, w_bit, q_config)
model.tie_weights()
# Load model weights
model = load_checkpoint_and_dispatch(model, model_path, device_map="balanced", no_split_module_classes=[self.layer_type])
model = load_checkpoint_and_dispatch(model, quant_path, device_map=device, no_split_module_classes=[self.layer_type])
return self(model, model_type, is_quantized=False)
return self(model, model_type, is_quantized=is_quantized)
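# Loading flow: resolve the checkpoint with snapshot_download, build the model on the meta device
# via init_empty_weights(), swap nn.Linear for WQLinear when is_quantized, then let accelerate's
# load_checkpoint_and_dispatch() place the weights. `device` is forwarded as the device_map, so
# "auto", "balanced", "sequential" or an explicit dict should work (assumption based on accelerate's API).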
@classmethod
def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
# Load config
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
# Initialize layers
def _load_quantized_modules(self, model, w_bit, q_config):
# Replace nn.Linear modules with quantized WQLinear modules
assert q_config["zero_point"], "We only support zero_point quantization now."
# Get blocks of model
layers = self.get_model_layers(model)
for i in tqdm(range(len(layers)), desc="Replacing layers..."):
layer = layers[i]
# Get every linear layer in a block
named_linears = get_named_linears(layer)
# Replace activation functions
self._scale_activations(self, layer)
# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
@@ -210,12 +248,6 @@ class BaseAWQForCausalLM:
torch.cuda.empty_cache()
gc.collect()
model.tie_weights()
model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced", no_split_module_classes=[self.layer_type])
return model
@staticmethod
def _scale_activations(self, layer):
......