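"""Entry script for AWQ: search for quantization scales/clips, quantize a model
with saved search results, or evaluate perplexity via the lm-eval harness.

Example invocations (paths are illustrative):
    python entry.py --entry_type search --model_path ./mpt-7b-8k-chat
    python entry.py --entry_type quant --model_path ./mpt-7b-8k-chat \
        --search_path ./awq-search.pt --quant_path ./awq-w4-g128.pt
    python entry.py --entry_type perplexity --model_path ./mpt-7b-8k-chat --device cuda:0
"""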
import argparse
import os

import torch
from lm_eval import evaluator
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor

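# Resolve the AWQ model class for a loaded HF model; only MPT is wired up here.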
def get_awq_model(model):
    from awq.models import MptAWQForCausalLM

    if "mpt" in str(model.__class__).lower():
        return MptAWQForCausalLM()
    else:
        raise NotImplementedError(type(model))

def load_unquantized(model_path):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)

    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True, **kwargs)

    model.eval()

    return model, tokenizer

def load_quantized(model_path):
    model, tokenizer = load_unquantized(model_path)
    awq_model = get_awq_model(model)
    return awq_model, tokenizer

def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")
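
    # The saved search results hold per-layer scaling factors ("scale") and
    # weight-clipping thresholds ("clip"); apply both to the in-memory model.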
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

def run_search(model_path, dump_path, w_bit, q_config):
    model, tokenizer = load_unquantized(model_path)
    awq_model = get_awq_model(model)
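
    # run_search=True computes the AWQ scales/clips and returns them without
    # modifying the weights, so the results can be saved and reused by run_quant.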
    awq_results = awq_model.quantize(
        model, tokenizer, w_bit=w_bit, q_config=q_config,
        run_search=True, run_quant=False,
    )

    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(awq_results, dump_path)

def run_quant(model_path, search_path, dump_path, w_bit, q_config):
    model, tokenizer = load_unquantized(model_path)
    load_search_result_into_memory(model, search_path)

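    # With the searched scales/clips applied, run_quant=True performs the
    # actual in-place weight quantization.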
    awq_model = get_awq_model(model)
    awq_model.quantize(
        model, w_bit=w_bit, q_config=q_config,
        run_search=False, run_quant=True,
    )

    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(model.cpu().state_dict(), dump_path)

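# Evaluate perplexity (WikiText) through the lm-eval harness.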
def run_perplexity(model_path, device):
    model, tokenizer = load_unquantized(model_path)

    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=['wikitext'],
        batch_size=1,
        no_cache=True,
        num_fewshot=0,
    )

    print(evaluator.make_table(results))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--entry_type', type=str, choices=['search', 'quant', 'perplexity'],
                        help='The type of task to run (search|quant|perplexity)')
    parser.add_argument('--model_path', type=str, default='./mpt-7b-8k-chat',
                        help='Path to hf model')
    parser.add_argument('--search_path', type=str,
                        default='./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt',
                        help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str,
                        default='./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt',
                        help='Path to save/load AWQ quant model')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
    parser.add_argument('--w_bit', type=int, default=4)
    parser.add_argument('--q_group_size', type=int, default=128)
    args = parser.parse_args()

    q_config = {"zero_point": True, "q_group_size": args.q_group_size}

    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, args.w_bit, q_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
    elif args.entry_type == 'perplexity':
        run_perplexity(args.model_path, args.device)
    else:
        raise Exception('--entry_type must be one of (search|quant|perplexity)')