Commit 934ad336 authored by Casper Hansen

Implement argparse CLI and perplexity evaluation

parent 5d4ab5dc
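
The diff below adds a small CLI around the AWQ search, quantization, and perplexity steps. As a rough orientation (a sketch, not part of the commit), the three entry points are meant to be run in sequence roughly like this, with paths and settings mirroring the defaults used in the diff:

q_config = {"zero_point": True, "q_group_size": 128}

# 1. Search for per-layer AWQ scales/clips and save them             (--entry_type search)
run_search("./mpt-7b-8k-chat", "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt", 4, q_config)

# 2. Apply the saved search results and write the quantized weights  (--entry_type quant)
run_quant("./mpt-7b-8k-chat", "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt",
          "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt", 4, q_config, "cuda:0")

# 3. Evaluate wikitext perplexity with lm_eval                        (--entry_type perplexity)
run_perplexity("./mpt-7b-8k-chat", "cuda:0")
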
import os
import torch
import argparse
from lm_eval import evaluator
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# NOTE: max_memory is currently a no-op: `(None or [])` always evaluates to an empty list,
# so the dict below is always empty. It looks like a placeholder for a future max-memory option.
max_memory = [v.split(':') for v in (None or [])]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}

def get_awq_model(model):
    from awq.models import MptAWQForCausalLM
@@ -27,33 +27,70 @@ def load_unquantized(model_path):
    return model, tokenizer

def load_quantized(model_path):
    awq_model = get_awq_model(model)

def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

-def run_search(model, dump_path):
+def run_search(model_path, dump_path, w_bit, q_config):
    model, tokenizer = load_unquantized(model_path)
    awq_model = get_awq_model(model)
-    awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config, run_search=True, run_quant=False)
+    awq_results = awq_model.quantize(model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(awq_results, dump_path)

-def run_quant(model, search_path, dump_path):
-    model, tokenizer = load_unquantized(model_path)
+def run_quant(model_path, search_path, dump_path, w_bit, q_config, device):
+    model, tokenizer = load_unquantized(model_path, device)
    load_search_result_into_memory(model, search_path)
    awq_model = get_awq_model(model)
-    awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)
+    awq_model.quantize(model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(model.cpu().state_dict(), dump_path)

-model_path = "./mpt-7b-8k-chat"
-search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
-quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
-q_config = { "zero_point": True, "q_group_size": 128 }
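
# run_perplexity evaluates the (unquantized) model on lm_eval's wikitext task and prints the
# results table; for this task the reported metrics are typically word_perplexity,
# byte_perplexity, and bits_per_byte (lower is better).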
def run_perplexity(model_path, device):
    model, tokenizer = load_unquantized(model_path)
    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=['wikitext'],
        batch_size=1,
        no_cache=True,
        num_fewshot=0,
    )
    print(evaluator.make_table(results))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
    parser.add_argument('--model_path', type=str, default='./mpt-7b-8k-chat', help='Path to the HF model')
    parser.add_argument('--search_path', type=str, default='./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt', help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str, default='./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt', help='Path to save/load the AWQ-quantized model')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load the model onto')
    parser.add_argument('--w_bit', type=int, default=4)
    parser.add_argument('--q_group_size', type=int, default=128)
    args = parser.parse_args()

    q_config = {"zero_point": True, "q_group_size": args.q_group_size}

    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, args.w_bit, q_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config, args.device)
    elif args.entry_type == 'perplexity':
        run_perplexity(args.model_path, args.device)
    else:
        raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
@@ -19,13 +19,16 @@ class BaseAWQForCausalLM:
    def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                 calib_data="pileval", init_only=False):
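        # When run_search=True, capture the per-layer search results (the "scale" and "clip"
        # entries consumed by apply_scale/apply_clip) and return them so callers can persist
        # them, e.g. run_search() saves the returned dict with torch.save().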
        search_result = None

        if run_search:
-            self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
+            search_result = self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
                                              auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant(model, w_bit, q_config, init_only)

        return search_result

    def _awq_quant(self, model, w_bit, q_config, init_only):
@@ -118,7 +121,7 @@ class BaseAWQForCausalLM:
        }

        # Run AWQ search layer by layer
-        for i in tqdm(range(len(layers)), desc="AWQ Search:"):
+        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)
...
@@ -6,13 +6,13 @@ import fnmatch
class LMEvalAdaptor(BaseLM):

-    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
+    def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
        super().__init__()
        assert isinstance(batch_size, int)
        self.model_name = model_name
-        self.model = model
+        self.model = model.to(device)
        self.model.eval()
        self.tokenizer = tokenizer
...