"vscode:/vscode.git/clone" did not exist on "1edd4e07d6ad52f4f63e7f6beaa5987c1e1cf621"
entry.py 4.1 KB
Newer Older
Ji Lin's avatar
Ji Lin committed
1
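"""Entry point for running AWQ weight quantization end to end:
search for scales/clips, quantize the weights, then evaluate perplexity.

Illustrative usage (note that __main__ below currently hardcodes the
mpt-7b-8k-chat paths, overriding the corresponding CLI flags):

    python entry.py --entry_type search
    python entry.py --entry_type quant
    python entry.py --entry_type perplexity --device cuda:0
"""
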
import os
import argparse

import torch
from lm_eval import evaluator
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
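

# Return the AWQ wrapper matching the loaded model's architecture;
# only MPT is wired up in this script.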
def get_awq_model(model):
    from awq.models import MptAWQForCausalLM

    if "mpt" in str(model.__class__).lower():
        return MptAWQForCausalLM()
    else:
        raise NotImplementedError(type(model))
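

# Load the unquantized fp16 model; the tokenizer name is read from the
# model config (MPT configs carry a tokenizer_name field).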
def load_unquantized(model_path):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)

    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True, **kwargs)

    model.eval()

    return model, tokenizer
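

# Reload a quantized checkpoint through AutoAWQ for evaluation.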
def load_quantized(model_path, quant_path, w_bit, q_config, device):
    from awq.models.auto import AutoAWQForCausalLM
    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    return model, tokenizer
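

# Restore saved AWQ search results (activation scales and weight clipping
# thresholds) into an unquantized model.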
def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")

    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])
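

# Step 1: search for per-channel scales and clipping values; weights are not
# quantized yet.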
def run_search(model_path, dump_path, w_bit, q_config):
    model, tokenizer = load_unquantized(model_path)
    awq_model = get_awq_model(model)
    awq_results = awq_model.quantize(
        model, tokenizer, w_bit=w_bit, q_config=q_config,
        run_search=True, run_quant=False)
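
    # Persist the search results so run_quant can reuse them.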
    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(awq_results, dump_path)
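

# Step 2: load the search results into the fp16 model, then quantize the
# weights in place.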
def run_quant(model_path, search_path, dump_path, w_bit, q_config):
    model, tokenizer = load_unquantized(model_path)
    load_search_result_into_memory(model, search_path)

    awq_model = get_awq_model(model)
    awq_model.quantize(
        model, w_bit=w_bit, q_config=q_config,
        run_search=False, run_quant=True)
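
    # Save the quantized weights as a plain state_dict on CPU.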
    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(model.cpu().state_dict(), dump_path)
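

# Step 3: score the quantized model's zero-shot WikiText perplexity with
# lm-eval-harness.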
def run_perplexity(model_path, quant_path, w_bit, q_config, device):
    model, tokenizer = load_quantized(model_path, quant_path, w_bit, q_config, device)

    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
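    # simple_evaluate(..., no_cache=True) follows the lm-eval 0.3.x API;
    # newer releases removed that flag.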
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=['wikitext'],
        batch_size=1,
        no_cache=True,
        num_fewshot=0,
    )

    print(evaluator.make_table(results))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
    parser.add_argument('--model_path', type=str, help='Path to hf model')
    parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
    parser.add_argument('--w_bit', type=int, default=4)
    parser.add_argument('--q_group_size', type=int, default=128)
    args = parser.parse_args()

    # NOTE: these hardcoded paths override any --model_path/--search_path/--quant_path CLI values.
    args.model_path = "./mpt-7b-8k-chat"
    args.search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
    args.quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
    q_config = { "zero_point": True, "q_group_size": args.q_group_size }
    
    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, args.w_bit, q_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
    elif args.entry_type == 'perplexity':
        run_perplexity(args.model_path, args.quant_path, args.w_bit, q_config, args.device)
    else:
        raise Exception('--entry_type must be one of (search|quant|perplexity)')