import os
import torch
import argparse
from lm_eval import evaluator
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor


def load_search_result_into_memory(model, search_path):
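    """Load AWQ search results (scales and clips) from disk and apply them to the model in-place."""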
    awq_results = torch.load(search_path, map_location="cpu")

    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

def run_search(model_path, dump_path, quant_config):
    """
    Step 1/2: Search calibration data (the Pile) for optimal scaling factors.
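
    Example (illustrative; paths mirror the usage notes in __main__):
        run_search("lmsys/vicuna-7b-v1.5", "vicuna-7b-v1.5-awq",
                   quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4})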
    """
    # Load model
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Run the AWQ search only: find scales/clips, but defer the weight rewrite to step 2
    model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)

    # Save search results
    model.save_quantized(dump_path)

    # Save tokenizer
    tokenizer.save_pretrained(dump_path)

def run_quant(model_path, search_path, dump_path, quant_config):
    """
    Step 2/2: Use the search results to quantize the model weights.
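
    Example (illustrative; continues from the run_search example above):
        run_quant("lmsys/vicuna-7b-v1.5",
                  "vicuna-7b-v1.5-awq/awq_model_search_result.pt",
                  "vicuna-7b-v1.5-awq",
                  quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4})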
    """
    # Load model and search results
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    load_search_result_into_memory(model.model, search_path)

    # Run actual weight quantization
    model.quantize(quant_config=quant_config, run_search=False, run_quant=True)

    # Save quantized model
    model.save_quantized(dump_path)

def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
    """
    Post-quantization: evaluate perplexity on wikitext with the EleutherAI Evaluation Harness.
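
    Example (illustrative; file names mirror the usage notes in __main__):
        run_eval("vicuna-7b-v1.5-awq", "awq_model_w4_g128.pt", "cuda:0",
                 "wikitext", task_batch_size=1, task_n_shot=0, task_use_pretrained=False)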
    """
    # Load model
    if task_use_pretrained:
        model = AutoAWQForCausalLM.from_pretrained(model_path)
    else:
        model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Wrap the model and tokenizer in the lm-eval adapter
    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)

    # Evaluate the model (quantized or FP16) on the requested tasks
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=tasks.split(','),
        batch_size=task_batch_size,
        no_cache=True,
        num_fewshot=task_n_shot,
    )

    print(evaluator.make_table(results))

if __name__ == '__main__':
    """
    - Run AWQ search and save result:
    python -m awq.entry --entry_type search --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq

    - Run AWQ to save the real quantized weights at the quant_path:
    python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq

    - Evaluate perplexity of the quantized model:
    python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt

    - Evaluate perplexity of the unquantized FP16 model:
    python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval)')
    parser.add_argument('--model_path', type=str, help='Path to the Hugging Face model')
    parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str, help='Directory to save the quantized AWQ model to')
    parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
    parser.add_argument('--w_bit', type=int, default=4, help='Weight quantization bit width')
    parser.add_argument('--q_group_size', type=int, default=128, help='Quantization group size (number of weights sharing one scale)')
    parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
                        'Separate tasks by comma for multiple tasks. '
                        'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
    parser.add_argument("--task_use_pretrained", default=False, action=argparse.BooleanOptionalAction,
                        help="Pass '--task_use_pretrained' to use a pretrained model running FP16")
    parser.add_argument('--task_batch_size', type=int, default=1)
    parser.add_argument('--task_n_shot', type=int, default=0)
    args = parser.parse_args()

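    # Assemble the quantization settings from the CLI flags: "w_bit" is the
    # weight bit width, "q_group_size" is how many weights share one scale,
    # and "zero_point" selects asymmetric (zero-point) quantization.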
    quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }

    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, quant_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
    elif args.entry_type == 'eval':
        run_eval(args.model_path, args.quant_file, args.device,
                 args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
    else:
        raise ValueError('--entry_type must be one of (search|quant|eval)')