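# entry.py: AWQ search / quantization / evaluation entry point.
# Runs the AWQ scale search, applies fake or real low-bit weight quantization,
# and evaluates the resulting model on lm_eval tasks.
#
# Example usage (illustrative paths, flags as defined below):
#   python entry.py --model_path /PATH/TO/MODEL --w_bit 4 --q_group_size 128 \
#       --run_awq --dump_awq awq_results.pt
#   python entry.py --model_path /PATH/TO/MODEL --w_bit 4 --q_group_size 128 \
#       --load_awq awq_results.pt --tasks wikitext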
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor


parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
                    help="enable model parallelism")
parser.add_argument('--auto_parallel', action='store_true',
                    help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true',
                    help="disable zero_point")
parser.add_argument('--q_backend', type=str,
                    default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None,
                    help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None,
                    help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true',
                    help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None,
                    help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None,
                    help="load the awq search results")
args = parser.parse_args()

if args.auto_parallel:
    gpu_list = auto_parallel(args)

# get quantization config (apart from w_bit)
q_config = {
    "zero_point": not args.no_zero_point,  # by default True
    "q_group_size": args.q_group_size,  # group size for weight quantization; -1 disables grouping
}
print("Quantization config:", q_config)

# build model and tokenizer

def build_model_and_enc(model_path):
    if not os.path.exists(model_path):  # look into ssd
        raise FileNotFoundError(f"{model_path} not found!")
    print(f"* Building model {model_path}")

    # all hf model
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
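    # MPT checkpoints reference an external tokenizer via config.tokenizer_name;
    # other models ship their own tokenizer alongside the weights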
    if "mpt" in config.__class__.__name__.lower():
        enc = AutoTokenizer.from_pretrained(config.tokenizer_name)
    else:
        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    if args.load_quant:  # directly load quantized weights
        # no need to really load the fp16 weights... just to get the model structure
        print("Loading pre-computed quantized weights...")
        with init_empty_weights():
            model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                         torch_dtype=torch.float16, trust_remote_code=True)
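        # replace Linear layers with empty quantized modules (init_only) so the
        # state-dict keys match the saved quantized checkpoint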
        real_quantize_model_weight(
            model, w_bit=args.w_bit, q_config=q_config, init_only=True)
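        # stream the saved quantized checkpoint onto the available devices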
        model = load_checkpoint_and_dispatch(
            model, args.load_quant, device_map="balanced",
            # TODO: can we remove this?
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
        )
    else:  # fp16 to quantized
        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
            
        if args.run_awq:
            assert args.dump_awq, "Please save the awq results with --dump_awq"
            
            # Init model on CPU; skip the default random weight init, since the
            # weights are overwritten right away by from_pretrained
            def skip(*args, **kwargs):
                pass

            torch.nn.init.kaiming_normal_ = skip
            torch.nn.init.kaiming_uniform_ = skip
            torch.nn.init.uniform_ = skip
            torch.nn.init.normal_ = skip
            model = AutoModelForCausalLM.from_pretrained(
                model_path, config=config, trust_remote_code=True, torch_dtype=torch.float16)
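            # run the AWQ scale/clip search on a small calibration set
            # (n_samples sequences of seqlen tokens)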
            awq_results = run_awq(
                model, enc,
                w_bit=args.w_bit, q_config=q_config,
                n_samples=128, seqlen=512,
            )
            if args.dump_awq:
                dirpath = os.path.dirname(args.dump_awq)
                os.makedirs(dirpath, exist_ok=True)
                
                torch.save(awq_results, args.dump_awq)
                print("AWQ results saved at", args.dump_awq)
                
            exit(0)
        else:
            # Load the fp16 model (AWQ results and weight quantization are applied below)
            # Init model on CPU:
            kwargs = {"torch_dtype": torch.float16}
            model = AutoModelForCausalLM.from_pretrained(
                model_path, config=config, trust_remote_code=True, **kwargs)
                
        if args.load_awq:
            print("Loading pre-computed AWQ results from", args.load_awq)
            awq_results = torch.load(args.load_awq, map_location="cpu")
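            # fold the pre-computed scaling and clipping factors into the fp16 model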
            apply_awq(model, awq_results)

        # weight quantization
        if args.w_bit is not None:
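            # "fake" rounds the fp16 weights in place (for accuracy evaluation);
            # "real" packs them into low-bit storage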
            if args.q_backend == "fake":
                assert args.dump_quant is None, \
                    "Need to use real quantization to dump quantized weights"
                pseudo_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
            elif args.q_backend == "real":  # real quantization
                real_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )
                if args.dump_quant:
                    dirpath = os.path.dirname(args.dump_quant)
                    os.makedirs(dirpath, exist_ok=True)
                    
                    print(
                        f"Saving the quantized model at {args.dump_quant}...")
                    torch.save(model.cpu().state_dict(), args.dump_quant)
                    exit(0)
            else:
                raise NotImplementedError
            
            # Move the model to GPU (as much as possible) for LM evaluation
            kwargs = {
                "torch_dtype": torch.float16, 
                "device_map": "auto", 
                "max_memory": {0: "8GiB", "cpu": "99GiB"}
            }
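            # re-load through from_pretrained with the (pseudo-)quantized state_dict
            # so accelerate can place the model across GPU and CPU memory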
            model = AutoModelForCausalLM.from_pretrained(
                model_path, config=config, state_dict=model.state_dict(), trust_remote_code=True, **kwargs)

    return model, enc


def main():
    if args.output_path is not None and os.path.exists(args.output_path):
        # print(f"Results {args.output_path} already generated. Exit.")
        print(f"Results {args.output_path} already generated. Overwrite.")
        # exit()

    if args.dump_awq and os.path.exists(args.dump_awq):
        print(f"Found existing AWQ results {args.dump_awq}, exit.")
        exit()

    # a hack here to auto set model group
    model, enc = build_model_and_enc(args.model_path)

    if args.tasks is not None:
        task_names = args.tasks.split(",")

        lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
        results = evaluator.simple_evaluate(
            model=lm_eval_model,
            tasks=task_names,
            batch_size=args.batch_size,
            no_cache=True,
            num_fewshot=args.num_fewshot,
        )

        print(evaluator.make_table(results))

        if args.output_path is not None:
            os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok=True)
            # replace the model object with its path so the results are JSON-serializable
            results["config"]["model"] = args.model_path
            with open(args.output_path, "w") as f:
                json.dump(results, f, indent=2)


if __name__ == '__main__':
    main()