base.py 13 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import gc
3
import json
Casper Hansen's avatar
Casper Hansen committed
4
5
6
import torch
import functools
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
7
from tqdm import tqdm
Casper Hansen's avatar
Casper Hansen committed
8
9
from collections import defaultdict

10
from huggingface_hub import snapshot_download
Casper Hansen's avatar
Casper Hansen committed
11
from awq.utils.calib_data import get_calib_dataset
12
13
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
14
15
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
16
17
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
Casper Hansen's avatar
Casper Hansen committed
18
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
Casper Hansen's avatar
Casper Hansen committed
19

20
class BaseAWQForCausalLM(nn.Module):
    """Wrap a Hugging Face causal LM with AWQ search, quantization and (de)serialization.

    Model-specific subclasses are expected to provide ``layer_type``,
    ``get_model_layers``, ``get_act_for_scaling`` and ``move_embed``; they are
    used here but not defined in this class.
    """

    def __init__(self, model, model_type, is_quantized, quant_config):
        """Store the wrapped model and its quantization metadata.

        Args:
            model: The underlying ``PreTrainedModel``.
            model_type: Architecture identifier string.
            is_quantized: Whether ``model`` already contains ``WQLinear`` layers.
            quant_config: Dict read via the ``zero_point``, ``w_bit`` and
                ``q_group_size`` keys when quantizing or saving.
        """
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        # Output of ``_awq_search`` ({"scale": [...], "clip": [...]}) once a search has run.
        self.search_result = None
        self.quant_config: dict = quant_config

    def to(self, device: str):
        """Move the wrapped model to ``device``.

        NOTE: unlike ``nn.Module.to`` this returns the wrapped model, not ``self``.
        """
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Run ``self.model.generate`` under ``torch.inference_mode()``."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                 calib_data="pileval"):
        """Optionally run the AWQ search and/or quantize the model in place.

        Args:
            tokenizer: Tokenizer used to build the calibration set (search only).
            quant_config: Quantization settings. When ``None`` (the default) the
                config already attached to this wrapper is reused instead of a
                shared empty dict.
            n_samples: Number of calibration samples (search only).
            seqlen: Calibration block size (search only).
            auto_scale: Search and apply activation-aware scales.
            mse_range: Search and apply weight clipping ranges.
            run_search: Run ``_awq_search`` and store its result in ``self.search_result``.
            run_quant: Replace linear layers with packed ``WQLinear`` modules.
            calib_data: Calibration dataset name passed to ``get_calib_dataset``.
        """
        if quant_config is None:
            quant_config = self.quant_config if self.quant_config is not None else {}
        self.quant_config = quant_config

        if run_search:
            self.search_result = self._awq_search(
                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
            )

        if run_quant:
            self._awq_quant()

    def _awq_quant(self):
        """Replace every linear layer of each block with a packed ``WQLinear``.

        Each weight is pseudo-quantized on the GPU with ``self.quant_config``;
        the returned scales / zero points are transposed and packed into the
        new layer, which is placed on the block's device.
        """
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        for layer in tqdm(layers, desc="AWQ Quantization"):
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    **self.quant_config
                )
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()

                q_linear = WQLinear.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,
                    scales,
                    zeros,
                )

                # Release the fp weights from the GPU before moving on.
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data="pileval"):
        """Search per-block AWQ scales and clipping ranges on calibration data.

        Blocks are processed one at a time on the GPU: the captured input of
        block ``i`` is run through it to obtain the input of block ``i + 1``,
        and the scales/clips found for a block are applied to it in place.

        Returns:
            dict with ``"scale"`` and ``"clip"`` lists whose entries are
            prefixed with the block's module name via ``append_str_prefix``.

        Raises:
            RuntimeError: if the input of the first block could not be captured.
        """
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        # Capture the input and keyword arguments of layer 0. Forward hooks
        # with kwargs need PyTorch 2.0, so wrap the layer instead and abort the
        # forward pass as soon as its inputs have been recorded.
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps.append(inp)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit: the rest of the model is not needed

        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # raised on purpose by Catcher
            pass
        finally:
            # Restore layer 0 even if the forward pass failed for another reason.
            layers[0] = layers[0].module
        del samples
        if not inps:
            raise RuntimeError("AWQ search could not capture the inputs of the first layer.")
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()

        awq_results = {
            "scale": [],
            "clip": [],
        }

        def cache_input_hook(m, x, y, name, feat_dict):
            # Keep a CPU copy of the first positional input of the linear layer.
            feat_dict[name].append(x[0].detach().cpu())

        for layer in tqdm(layers, desc="AWQ Search"):
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # First, record the input features of every linear layer in the block.
            input_feat = defaultdict(list)
            handles = [
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
                for name in named_linears
            ]
            inps = inps.to(next(layer.parameters()).device)  # in case of multi-GPU
            # The block output becomes the next block's input.
            inps = layer(inps, **layer_kwargs)[0]
            for handle in handles:
                handle.remove()
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            torch.cuda.empty_cache()

            if auto_scale:
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )
                # input_feat is passed along so the cached features can follow the scales.
                apply_scale(layer, scales_list, input_feat_dict=input_feat)
                # Prefix with the block name to make the entries model-global.
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )
                apply_clip(layer, clip_list)
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer.cpu()
            # NOTE (Haotian): check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    @staticmethod
    def _quantized_model_filename(quant_config):
        """Return the checkpoint filename used for weights quantized with ``quant_config``."""
        return f"awq_model_w{quant_config['w_bit']}_g{quant_config['q_group_size']}.pt"

    def save_quantized(self, save_dir):
        """Save the quantized model (or the AWQ search results) to ``save_dir``.

        The directory receives the files written by ``save_pretrained`` (called
        with an empty state dict, after which ``pytorch_model.bin`` is removed),
        a ``torch.save`` checkpoint and ``quant_config.json``. The checkpoint is
        the model's state dict when no search result is stored, otherwise
        ``self.search_result``.

        Args:
            save_dir: Target directory; a trailing ``/`` is accepted.
        """
        if self.search_result is None:
            # quant_config is a dict, so it must be indexed, not attribute-accessed.
            model_name = self._quantized_model_filename(self.quant_config)
            payload = self.model.state_dict()
        else:
            model_name = 'awq_model_search_result.pt'
            payload = self.search_result

        # Let transformers write the config files, but without any weights ...
        self.model.save_pretrained(save_dir, state_dict=nn.Module().state_dict())
        # ... and drop the empty weights file it produced.
        os.remove(os.path.join(save_dir, 'pytorch_model.bin'))

        torch.save(payload, os.path.join(save_dir, model_name))

        with open(os.path.join(save_dir, 'quant_config.json'), 'w') as file:
            json.dump(self.quant_config, file, indent=4)

    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True):
        """Load an unquantized model from ``model_path`` (directory or hub id)."""
        return cls.from_quantized(
            model_path,
            model_type,
            model_filename='',
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            is_quantized=False,
        )

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename,
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                       safetensors=False, is_quantized=True):
        """Build the model skeleton, swap in quantized layers and load weights.

        Args:
            model_path: Local directory, or a hub id downloaded via ``snapshot_download``.
            model_type: Architecture identifier stored on the wrapper.
            model_filename: Checkpoint file inside ``model_path`` ('' for the directory).
            device: ``device_map`` passed to ``load_checkpoint_and_dispatch``.
            torch_dtype: dtype used to build and load the model.
            trust_remote_code: Forwarded to every transformers loader.
            safetensors: Download safetensors files instead of ``.pt``/``.bin``.
            is_quantized: Replace ``nn.Linear`` layers with ``WQLinear`` before loading.
        """
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        # TODO: Better naming, model_filename becomes a directory when it is ''.
        model_filename = model_path + f'/{model_filename}'

        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.load(file)
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # Build the model without allocating real weights.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only AWQ-quantized checkpoints need WQLinear layers in place before loading.
        if is_quantized:
            cls._load_quantized_modules(cls, model, quant_config)

        model.tie_weights()

        try:
            model = load_checkpoint_and_dispatch(
                model, model_filename, device_map=device, no_split_module_classes=[cls.layer_type])
        except Exception as ex:
            # Fall back to a regular transformers load if dispatching fails.
            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')

            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[cls.layer_type],
                dtype=torch_dtype,
            )

            del model

            model = AutoModelForCausalLM.from_pretrained(
                model_filename,
                device_map=device_map,
                offload_folder="offload",
                offload_state_dict=True,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )
            model.eval()

        return cls(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config):
        """Replace every linear layer of each block with an empty ``WQLinear``.

        Called from ``from_quantized`` with the class itself passed as ``self``,
        so it only relies on class-level helpers.
        """
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    module, quant_config['w_bit'], quant_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        """Wrap the block's scalable activation in a ``ScaledActivation``.

        ``self`` is passed explicitly (instance or class) and must provide
        ``get_act_for_scaling``. Nothing happens when the activation is not
        scalable or is already reported as a ``ScaledActivation``.
        """
        scale_dict = self.get_act_for_scaling(layer)

        if not scale_dict['is_scalable']:
            return
        if isinstance(scale_dict['scale_layer'], ScaledActivation):
            return

        param = next(layer.parameters())

        # Start from all-ones scales in the block's dtype/device.
        scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)