from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from accelerate.utils.modeling import get_balanced_memory
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model


parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
                    help="enable model parallelism")
# per-device memory limits, used to offload parts of larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
                    help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
                        + "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
                        + "more details here: " \
                        + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
                    help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true',
                    help="disable zero_point")
parser.add_argument('--q_backend', type=str,
                    default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None,
                    help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None,
                    help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true',
                    help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None,
                    help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None,
                    help="load the awq search results")
args = parser.parse_args()

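# parse --max_memory entries such as "0:10GiB cpu:30GiB" into {0: "10GiB", "cpu": "30GiB"}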
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}

if args.auto_parallel:
    gpu_list = auto_parallel(args)

# get quantization config (apart from w_bit)
q_config = {
    "zero_point": not args.no_zero_point,  # by default True
    "q_group_size": args.q_group_size,  # group size for grouped quantization (-1 = no grouping)
}
print("Quantization config:", q_config)

# build model and tokenizer

def build_model_and_enc(model_path):
    if not os.path.exists(model_path):  # only local checkpoints are supported
        raise FileNotFoundError(f"{model_path} not found!")
    print(f"* Building model {model_path}")

    # load the HF config and tokenizer (MPT models keep their tokenizer name in the config)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if "mpt" in config.__class__.__name__.lower():
        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
    else:
        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

    if args.load_quant:  # directly load quantized weights
        print("Loading pre-computed quantized weights...")
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config,
                                                     torch_dtype=torch.float16, trust_remote_code=True)
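        # replace nn.Linear layers with (empty) quantized modules so the pre-quantized checkpoint can be loaded into them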
        real_quantize_model_weight(
            model, w_bit=args.w_bit, q_config=q_config, init_only=True)
        
        # re-tie shared weights (e.g. embeddings) before inferring the device map
        model.tie_weights()
        
        # Infer device map
        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
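        # keep each transformer block on a single device when sharding across GPUs/CPU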
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
            **kwargs
        )
        # Load checkpoint in the model
        load_checkpoint_in_model(
            model,
            checkpoint=args.load_quant,
            device_map=device_map,
            offload_state_dict=True,
        )
        # Dispatch model
        model = simple_dispatch_model(model, device_map=device_map)

        model.eval()
    else:  # fp16 to quantized
        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
        # Init model on CPU:
        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
        model = AutoModelForCausalLM.from_pretrained(
            model_path, config=config, trust_remote_code=True, **kwargs)

        model.eval()

        if args.run_awq:
            assert args.dump_awq, "Please save the awq results with --dump_awq"
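            # run the activation-aware scale search on a small calibration set (n_samples=128, seqlen=512)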
            awq_results = run_awq(
                model, enc,
                w_bit=args.w_bit, q_config=q_config,
                n_samples=128, seqlen=512,
            )
            if args.dump_awq:
                dirpath = os.path.dirname(args.dump_awq)
                if dirpath:  # --dump_awq may be a bare filename with no directory part
                    os.makedirs(dirpath, exist_ok=True)
                
                torch.save(awq_results, args.dump_awq)
                print("AWQ results saved at", args.dump_awq)
                
            exit(0)
                
        if args.load_awq:
            print("Loading pre-computed AWQ results from", args.load_awq)
            awq_results = torch.load(args.load_awq, map_location="cpu")
            apply_awq(model, awq_results)

        # weight quantization
        if args.w_bit is not None:
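            # "fake" simulates low-bit weights in fp16 (quantize-dequantize); "real" packs them into true low-bit storage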
            if args.q_backend == "fake":
                assert args.dump_quant is None, \
                    "Need to use real quantization to dump quantized weights"
                pseudo_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
            elif args.q_backend == "real":  # real quantization
                real_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
                if args.dump_quant:
                    dirpath = os.path.dirname(args.dump_quant)
                    if dirpath:  # --dump_quant may be a bare filename with no directory part
                        os.makedirs(dirpath, exist_ok=True)
                    
                    print(
                        f"Saving the quantized model at {args.dump_quant}...")
                    torch.save(model.cpu().state_dict(), args.dump_quant)
                    exit(0)
            else:
                raise NotImplementedError
            
        # Move the model to GPU (as much as possible) for LM evaluation
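        # get_balanced_memory spreads the weights evenly across the available devices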
        kwargs = {"max_memory": get_balanced_memory(
            model, max_memory if len(max_memory) > 0 else None)}
        device_map = infer_auto_device_map(
            model,
            # TODO: can we remove this?
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
            **kwargs
        )
        model = dispatch_model(model, device_map=device_map)

    return model, enc


def main():
    if args.output_path is not None and os.path.exists(args.output_path):
        # print(f"Results {args.output_path} already generated. Exit.")
        print(f"Results {args.output_path} already generated. Overwrite.")
        # exit()

    if args.dump_awq and os.path.exists(args.dump_awq):
        print(f"Found existing AWQ results {args.dump_awq}, exit.")
        exit()

    # a hack here to auto set model group
    model, enc = build_model_and_enc(args.model_path)

    if args.tasks is not None:
        task_names = args.tasks.split(",")

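        # wrap the HF model so the lm-eval harness can drive it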
        lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
        results = evaluator.simple_evaluate(
            model=lm_eval_model,
            tasks=task_names,
            batch_size=args.batch_size,
            no_cache=True,
            num_fewshot=args.num_fewshot,
        )

        print(evaluator.make_table(results))

        if args.output_path is not None:
            dirpath = os.path.dirname(args.output_path)
            if dirpath:
                os.makedirs(dirpath, exist_ok=True)
            # replace the model object with its path so the results can be JSON-serialized
            results["config"]["model"] = args.model_path
            with open(args.output_path, "w") as f:
                json.dump(results, f, indent=2)


if __name__ == '__main__':
    main()