from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_and_dispatch
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
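
# Example invocations (illustrative only; model paths, cache paths, and task names are placeholders):
#   python entry.py --model_path /PATH/TO/MODEL --w_bit 4 --q_group_size 128 --run_awq --dump_awq awq_cache/model.pt
#   python entry.py --model_path /PATH/TO/MODEL --w_bit 4 --q_group_size 128 --load_awq awq_cache/model.pt --q_backend fake --tasks wikitext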


parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
                    help="enable model parallelism")
# max memory per device, used to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
                    help="List of device_id:max_memory pairs to be parsed into a dictionary; "
                         "Example: 0:10GiB 1:10GiB cpu:30GiB; "
                         "more details here: "
                         "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
                    help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true',
                    help="disable zero_point")
parser.add_argument('--q_backend', type=str,
                    default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None,
                    help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None,
                    help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true',
                    help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None,
                    help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None,
                    help="load the awq search results")
args = parser.parse_args()

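# Parse "--max_memory 0:10GiB 1:10GiB cpu:30GiB"-style entries into a {device: capacity} dict for accelerate.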
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}

if args.auto_parallel:
    gpu_list = auto_parallel(args)

# get quantization config (apart from w_bit)
q_config = {
    "zero_point": not args.no_zero_point,  # by default True
    "q_group_size": args.q_group_size,  # group size for group-wise quantization; -1 disables grouping
}
print("Quantization config:", q_config)

# build model and tokenizer

def build_model_and_enc(model_path):
    if not os.path.exists(model_path):  # look into ssd
        raise FileNotFoundError(f"{model_path} not found!")
    print(f"* Building model {model_path}")

    # all HF models: load the config, then the matching tokenizer
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if "mpt" in config.__class__.__name__.lower():
        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
    else:
        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

    if args.load_quant:  # directly load quantized weights
        print("Loading pre-computed quantized weights...")
        with init_empty_weights():
            model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                         torch_dtype=torch.float16, trust_remote_code=True)
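        # Set up the low-bit weight containers only (init_only=True); the actual values are
        # filled in by load_checkpoint_and_dispatch below.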
        real_quantize_model_weight(
            model, w_bit=args.w_bit, q_config=q_config, init_only=True)
        model = load_checkpoint_and_dispatch(
            model, args.load_quant, device_map="balanced",
            # TODO: can we remove this?
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
        )
    else:  # fp16 to quantized
        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
        # Init model on CPU:
        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
        model = AutoModelForCausalLM.from_pretrained(
            model_path, config=config, trust_remote_code=True, **kwargs)

        if args.run_awq:
            assert args.dump_awq, "Please save the awq results with --dump_awq"
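            # Run the AWQ search on a small calibration set (128 samples of sequence length 512).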
                        
            awq_results = run_awq(
                model, enc,
                w_bit=args.w_bit, q_config=q_config,
                n_samples=128, seqlen=512,
            )
            if args.dump_awq:
                dirpath = os.path.dirname(args.dump_awq)
                if dirpath:  # the dump path may be a bare filename in the current directory
                    os.makedirs(dirpath, exist_ok=True)
                
                torch.save(awq_results, args.dump_awq)
                print("AWQ results saved at", args.dump_awq)
                
            exit(0)
                
        if args.load_awq:
            print("Loading pre-computed AWQ results from", args.load_awq)
            awq_results = torch.load(args.load_awq, map_location="cpu")
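            # Apply the saved per-channel scales and clipping thresholds to the fp16 model in place.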
            apply_awq(model, awq_results)

        # weight quantization
        if args.w_bit is not None:
            if args.q_backend == "fake":
                assert args.dump_quant is None, \
                    "Need to use real quantization to dump quantized weights"
                pseudo_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
            elif args.q_backend == "real":  # real quantization
                real_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
                if args.dump_quant:
                    dirpath = os.path.dirname(args.dump_quant)
                    if dirpath:  # the dump path may be a bare filename in the current directory
                        os.makedirs(dirpath, exist_ok=True)
                    
                    print(
                        f"Saving the quantized model at {args.dump_quant}...")
                    torch.save(model.cpu().state_dict(), args.dump_quant)
                    exit(0)
            else:
                raise NotImplementedError
            
        # Move the model to GPU (as much as possible) for LM evaluation
        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
        device_map = infer_auto_device_map(
            model,
            # TODO: can we remove this?
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
            **kwargs
        )
        model = dispatch_model(model, device_map=device_map)

    return model, enc


def main():
    if args.output_path is not None and os.path.exists(args.output_path):
        # print(f"Results {args.output_path} already generated. Exit.")
        print(f"Results {args.output_path} already generated. Overwrite.")
        # exit()

    if args.dump_awq and os.path.exists(args.dump_awq):
        print(f"Found existing AWQ results {args.dump_awq}, exit.")
        exit()

    model, enc = build_model_and_enc(args.model_path)

    if args.tasks is not None:
        task_names = args.tasks.split(",")

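        # Wrap the model and tokenizer in the lm-eval-harness adaptor interface.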
        lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
        results = evaluator.simple_evaluate(
            model=lm_eval_model,
            tasks=task_names,
            batch_size=args.batch_size,
            no_cache=True,
            num_fewshot=args.num_fewshot,
        )

        print(evaluator.make_table(results))

        if args.output_path is not None:
            dirpath = os.path.dirname(args.output_path)
            if dirpath:
                os.makedirs(dirpath, exist_ok=True)
            # the model object itself is not JSON-serializable, so record its path instead
            results["config"]["model"] = args.model_path
            with open(args.output_path, "w") as f:
                json.dump(results, f, indent=2)


if __name__ == '__main__':
    main()