entry.py
import os
import time
import argparse

import torch
from lm_eval import evaluator
from transformers import AutoTokenizer

from awq import AutoAWQForCausalLM
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor


def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")
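    # The checkpoint holds the per-layer scaling factors ("scale") and
    # clipping thresholds ("clip") found during the search step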

    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

def run_search(model_path, dump_path, quant_config):
    """
    Step 1/2: Search the Pile calibration data for optimal scaling factors and clipping thresholds.
    """
    # Load model
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Quantize
    model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)
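    # run_search=True only collects scales and clipping thresholds on calibration
    # data; run_quant=False leaves the FP16 weights untouched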

    # Save search results
    model.save_quantized(dump_path)

    # Save tokenizer
    tokenizer.save_pretrained(dump_path)

def run_quant(model_path, search_path, dump_path, quant_config):
    """
    Step 2/2: Use the search results to quantize the model weights.
    """
    # Load model and search results
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    load_search_result_into_memory(model.model, search_path)

    # Run actual weight quantization
    model.quantize(quant_config=quant_config, run_search=False, run_quant=True)
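    # This bakes the searched scales/clips (applied in memory above) into the
    # real quantized weights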

    # Save quantized model
    model.save_quantized(dump_path)

def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
    """
    Post-quantization: Evaluate perplexity on wikitext with the EleutherAI Evaluation Harness.
    """
    # Load model
    if task_use_pretrained:
        model = AutoAWQForCausalLM.from_pretrained(model_path)
    else:
        model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)
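    # task_use_pretrained evaluates the unquantized FP16 model as a baseline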

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Wrap the model so the lm-eval harness can drive it
    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)

    # Run the evaluation on the requested tasks
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=tasks.split(','),
        batch_size=task_batch_size,
        no_cache=True,
        num_fewshot=task_n_shot,
    )

    print(evaluator.make_table(results))

@torch.inference_mode()
def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, batch_size=1, disable_fused_layers=False):
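    # Helper: returns the callable's result together with its wall-clock duration in seconds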
    def _timer(func):
        start = time.time()
        out = func()
        return out, time.time() - start

    def _generate(model, model_out, n_generate, batch_size):
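        # Sample one token per sequence each step, reusing and extending the KV cache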
        past_key_values = model_out.past_key_values

        for _ in range(n_generate):
            logits = model_out.logits[:, -1, :]
            new_tokens = []

            for batch_index in range(batch_size):
                probs = torch.softmax(logits[batch_index], dim=-1)
                token = torch.multinomial(probs, num_samples=1)
                new_tokens.append(token)
            
            tokens = torch.cat(new_tokens).unsqueeze(-1)

            model_out = model(tokens, use_cache=True, past_key_values=past_key_values)
            # Carry the refreshed cache forward so each step attends to all previous tokens
            past_key_values = model_out.past_key_values

    def _warmup(device: str):
        warm_up = torch.randn((4096, 4096)).to(device)
        torch.mm(warm_up, warm_up)

    if quant_file:
        fuse_layers = not disable_fused_layers
        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=fuse_layers))
    else:
        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    _warmup(device)
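    # Warming up CUDA before timing keeps context/kernel initialization out of
    # the measurements below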

    # Generate random inputs
    # Prompt length is the total context budget minus the tokens to be generated
    n_context = n_context - n_generate
    ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).to(device)

    # Context stage
    model_out, context_time = _timer(lambda: model(ids, use_cache=True))
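    # The context (prefill) pass processes the whole prompt at once; its KV cache
    # is then reused and extended by _generate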

    # Generation stage
    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate, batch_size))

    # Summarize timings and memory
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    context_tokens_per_second = n_context / context_time
    context_ms_per_token = (context_time*1000) / n_context
    inference_tokens_per_second = n_generate / generation_time
    inference_ms_per_token = (generation_time*1000) / n_generate

    print(f"[======] Model summary: {model_path} [======]")
    print(f"[*] Load time: {load_time:.2f} seconds")
    print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
    print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
    print(f"[*] VRAM: {memory_used:.2f} MB")

if __name__ == '__main__':
    """
    - Run AWQ search and save result:
    python -m awq.entry --entry_type search --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq

    - Run AWQ to save the real quantized weights at the quant_path:
    python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq

    - Evaluate perplexity of the quantized model:
    python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt

    - Evaluate perplexity of the unquantized FP16 model:
    python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained

    - Run a speedtest to benchmark the quantized model:
    python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --n_generate 128 --n_context 256

    - Run a speedtest to benchmark the unquantized FP16 model:
    python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --n_context 256
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
    parser.add_argument('--model_path', type=str, help='Path to hf model')
    parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str, help='Path to save AWQ model to directory')
    parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
    parser.add_argument('--w_bit', type=int, default=4)
    parser.add_argument('--q_group_size', type=int, default=128)
    parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
                        'Separate tasks by comma for multiple tasks. '
                        'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
    parser.add_argument("--task_use_pretrained", default=False, action='store_true',
                        help="Pass '--task_use_pretrained' to evaluate the unquantized FP16 model")
    parser.add_argument('--task_batch_size', type=int, default=1)
    parser.add_argument('--task_n_shot', type=int, default=0)
    parser.add_argument('--n_generate', type=int, default=128)
    parser.add_argument('--n_context', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument("--disable_fused_layers", default=False, action='store_true',
                        help="Pass '--disable_fused_layers' to disable fused layers")
    args = parser.parse_args()

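    # zero_point=True selects asymmetric (zero-point) quantization; weights are
    # quantized to w_bit bits per group of q_group_size channels (W4/g128 by default)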
    quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }

    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, quant_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
    elif args.entry_type == 'eval':
        run_eval(args.model_path, args.quant_file, args.device,
                 args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
    elif args.entry_type == 'speed':
        if args.batch_size > 1 and not args.disable_fused_layers:
            raise ValueError('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
        
        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
    else:
        raise ValueError('--entry_type must be one of (search|quant|eval|speed)')