import os
import time
import torch
import argparse
from lm_eval import evaluator
from awq import AutoAWQForCausalLM
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from transformers import AutoTokenizer, GenerationConfig


def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")

    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

def run_search(model_path, dump_path, quant_config):
    """
    Step 1/2: Search the Pile calibration data for optimal scaling factors.
    """
    # Load model
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Quantize
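    # run_search=True only searches for scales/clips; run_quant=False leaves
    # the weights themselves untouched at this stage.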
    model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)

    # Save search results
    model.save_quantized(dump_path)

    # Save tokenizer
    tokenizer.save_pretrained(dump_path)

def run_quant(model_path, search_path, dump_path, quant_config):
    """
    Step 2/2: Use the search results to quantize the model weights.
    """
    # Load model and search results
    model = AutoAWQForCausalLM.from_pretrained(model_path)
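    # model.model is the underlying HF transformer inside the AutoAWQ wrapper.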
    load_search_result_into_memory(model.model, search_path)

    # Run actual weight quantization
    model.quantize(quant_config=quant_config, run_search=False, run_quant=True)

    # Save quantized model
    model.save_quantized(dump_path)

def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
    """
    Post quantization: Evaluate the chosen tasks (wikitext perplexity by default) with the EleutherAI LM Evaluation Harness
    """
    # Load model
    if task_use_pretrained:
        model = AutoAWQForCausalLM.from_pretrained(model_path)
    else:
        model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Load adapter
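    # LMEvalAdaptor wraps the model and tokenizer in the lm-eval API so the
    # harness can drive them like any other language model.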
    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)

    # Evaluate perplexity of quantized model
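    # no_cache=True forces fresh forward passes instead of reusing results
    # cached by the harness from earlier runs.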
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=tasks.split(','),
        batch_size=task_batch_size,
        no_cache=True,
        num_fewshot=task_n_shot,
    )

    print(evaluator.make_table(results))

@torch.inference_mode()
def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, batch_size=1, disable_fused_layers=False):
    def _timer(func):
        start = time.time()
        out = func()
        return out, time.time() - start

    def _warmup(device: str):
        warm_up = torch.randn((4096, 4096)).to(device)
        torch.mm(warm_up, warm_up)

    if quant_file:
        fuse_layers = not disable_fused_layers
        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=fuse_layers))
    else:
        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    _warmup(device)

    # Generate random inputs
    n_context = n_context - n_generate
    ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).to(device)

    # Context stage
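    # (max_new_tokens=0 is intended to time only the prefill pass over the
    # prompt, with no decoding.)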
    _, context_time = _timer(lambda: model.generate(
        ids, 
        generation_config=GenerationConfig(
            max_new_tokens=0,
            min_new_tokens=0,
            use_cache=True
        )
    ))

    # Generation stage
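    # (An unreachable eos_token_id plus min_new_tokens pins the output length,
    # so exactly n_generate tokens are always produced.)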
    _, generation_time = _timer(lambda: model.generate(
        ids, 
        generation_config=GenerationConfig(
            max_new_tokens=n_generate,
            min_new_tokens=n_generate,
            forced_eos_token_id=-100,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=-100,
            use_cache=True
        )
    ))

    # Prints
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    context_tokens_per_second = n_context / context_time * batch_size
    context_ms_per_token = (context_time*1000) / n_context / batch_size
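    # Note: generation_time includes re-processing the prompt, so pure decode
    # throughput is slightly understated.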
    inference_tokens_per_second = n_generate / generation_time * batch_size
    inference_ms_per_token = (generation_time*1000) / n_generate / batch_size

    print(f"[=] Model summary: {model_path} [=]")
    print(f"[*] Load time: {load_time:.2f} seconds")
    print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
    print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
    print(f"[*] VRAM: {memory_used:.2f} MB")

if __name__ == '__main__':
    """
    - Run AWQ search and save result:
    python -m awq.entry --entry_type search --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq

    - Run AWQ to save the real quantized weights at the quant_path:
    python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq

    - Evaluate perplexity of the quantized model:
    python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt

    - Evaluate perplexity of the unquantized FP16 model:
    python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained

    - Run a speedtest to benchmark the quantized model:
    python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --n_generate 128 --n_context 256

    - Run a speedtest to benchmark the unquantized FP16 model:
    python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --n_context 256
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
    parser.add_argument('--model_path', type=str, help='Path to hf model')
    parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str, help='Directory path to save the quantized AWQ model to')
    parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
    parser.add_argument('--w_bit', type=int, default=4)
    parser.add_argument('--q_group_size', type=int, default=128)
    parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
                        'Separate multiple tasks with commas. '
                        'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
    parser.add_argument("--task_use_pretrained", default=False, action='store_true',
                        help="Pass '--task_use_pretrained' to use a pretrained model running FP16")
    parser.add_argument('--task_batch_size', type=int, default=1)
    parser.add_argument('--task_n_shot', type=int, default=0)
    parser.add_argument('--n_generate', type=int, default=128)
    parser.add_argument('--n_context', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument("--disable_fused_layers", default=False, action='store_true',
                        help="Pass '--disable_fused_layers' to disable fused layers")
    args = parser.parse_args()

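    # zero_point: asymmetric quantization with zero points; q_group_size:
    # weights per quantization group; w_bit: weight bit-width.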
    quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }

    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, quant_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
    elif args.entry_type == 'eval':
        run_eval(args.model_path, args.quant_file, args.device,
                 args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
    elif args.entry_type == 'speed':
        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
    else:
        raise ValueError('--entry_type must be one of (search|quant|eval|speed)')