import functools
import gc
import json
import logging
import os
from collections import defaultdict
from typing import List, Optional, Union

import torch
import torch.nn as nn
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
from huggingface_hub import snapshot_download
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from transformers.modeling_utils import shard_checkpoint

from awq.modules.act import ScaledActivation
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.utils.calib_data import get_calib_dataset
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.utils.utils import simple_dispatch_model


class BaseAWQForCausalLM(nn.Module):
    """Wrap a Hugging Face causal LM with AWQ search, quantization, saving and loading.

    This base class calls several hooks it does not define itself; model-specific
    subclasses are presumably expected to provide ``get_model_layers``,
    ``get_act_for_scaling``, ``move_embed`` and ``layer_type`` (and optionally
    ``max_new_tokens_key`` and an overriding ``fuse_layers``).
    """

    def __init__(self, model, model_type, is_quantized, quant_config):
        """Store the wrapped model and its quantization state.

        Args:
            model: The underlying ``transformers`` model.
            model_type: Model type identifier string.
            is_quantized: Whether ``model`` already holds WQLinear layers.
            quant_config: Quantization settings; this class reads the keys
                ``zero_point``, ``w_bit``, ``q_group_size`` and ``version``.
        """
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        # Result of ``_awq_search`` ({"scale": [...], "clip": [...]}) once a search has run.
        self.search_result = None
        self.quant_config: dict = quant_config

    def to(self, device: str):
        """Move the wrapped model to ``device``.

        Note: returns whatever ``self.model.to`` returns (the wrapped model),
        not this wrapper.
        """
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Call ``self.model.generate`` inside ``torch.inference_mode``."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config: Optional[dict] = None, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                 calib_data: Union[str, List[str]] = "pileval", split="train",
                 text_column="text"):
        """Run the AWQ scale/clip search and/or replace linears with quantized ones.

        Args:
            tokenizer: Tokenizer passed to ``get_calib_dataset``.
            quant_config: Quantization settings. The dict is copied, so the
                caller's object is never modified; ``version`` defaults to ``"GEMM"``.
            n_samples: Number of calibration samples.
            seqlen: Block size of each calibration sample.
            auto_scale: Search and apply activation-aware scales.
            mse_range: Search and apply weight clipping ranges.
            run_search: Run ``_awq_search`` and store its result in ``self.search_result``.
            run_quant: Run ``_awq_quant`` and mark the model as quantized.
            calib_data: Calibration source passed to ``get_calib_dataset``
                (a dataset name such as ``"pileval"`` or a list of strings).
            split: Dataset split to read calibration data from.
            text_column: Column holding the text in the calibration dataset.
        """
        # Work on a fresh dict so repeated calls never share state through a
        # default argument and the caller's dict is left untouched.
        quant_config = dict(quant_config) if quant_config is not None else {}
        quant_config.setdefault("version", "GEMM")
        self.quant_config = quant_config

        if run_search:
            self.search_result = self._awq_search(
                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
                split=split, text_column=text_column
            )

        if run_quant:
            self._awq_quant()
            self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
        """Hook for fusing layers of a loaded quantized model; a no-op in the base class."""
        pass

    @staticmethod
    def _select_qlinear_module(version):
        """Return the WQLinear class for kernel ``version``.

        Raises:
            ValueError: if ``version`` is neither ``'GEMM'`` nor ``'GEMV'``.
        """
        if version == 'GEMM':
            return WQLinear_GEMM
        if version == 'GEMV':
            return WQLinear_GEMV
        raise ValueError(f"Unsupported AWQ version {version!r}; expected 'GEMM' or 'GEMV'.")

    def _awq_quant(self):
        """Replace every linear of every block with a packed WQLinear module.

        Reads ``zero_point``, ``version``, ``w_bit`` and ``q_group_size`` from
        ``self.quant_config``.

        Raises:
            AssertionError: if ``zero_point`` is falsy.
            ValueError: if ``version`` is not ``'GEMM'`` or ``'GEMV'``.
        """
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        version = self.quant_config["version"]
        # Resolve the kernel class first so an unknown version fails before any layer is touched.
        q_linear_module = self._select_qlinear_module(version)
        w_bit = self.quant_config["w_bit"]
        q_group_size = self.quant_config["q_group_size"]

        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for layer in tqdm(layers, desc="AWQ Quantization"):
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                # Overwrites the weight in place and returns the scales/zeros used.
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    w_bit=w_bit,
                    q_group_size=q_group_size
                )

                if version == 'GEMM':
                    # GEMM packing takes scales/zeros transposed; GEMV takes them as returned.
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()

                # NOTE(review): the 4th positional argument is False here and True in
                # _load_quantized_modules (presumably an "init only" flag) — confirm.
                q_linear = q_linear_module.from_linear(
                    module,
                    w_bit,
                    q_group_size,
                    False,
                    scales,
                    zeros
                )

                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data: Union[str, List[str]] = "pileval",
                    split="train", text_column="text"):
        """Search per-block AWQ scales and clipping ranges on calibration data.

        Blocks are processed one at a time on CUDA. The inputs of block 0 are
        captured first, and each block's output then becomes the next block's
        input. Scales and clips are applied to the model in place as they are found.

        Returns:
            ``{"scale": [...], "clip": [...]}``, with entries prefixed by the
            block's module name within ``self.model``.

        Raises:
            RuntimeError: if the inputs of the first block could not be captured.
        """
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
            split=split, text_column=text_column
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        # Capture the input and kwargs of layer 0 and then abort the forward pass.
        # with_kwargs forward hooks need PyTorch 2.0, hence this wrapper. A
        # dedicated exception makes sure genuine model errors are not swallowed.
        class _StopForward(Exception):
            pass

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise _StopForward  # early exit to break later inference

        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except _StopForward:
            pass
        finally:
            # Always restore the real layer 0, even if the forward pass failed.
            layers[0] = layers[0].module
        del samples

        if not inps:
            raise RuntimeError("AWQ search could not capture the inputs of the first layer.")
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        def cache_input_hook(m, x, y, name, feat_dict):
            # Forward hook: store the module's first positional input on CPU.
            feat_dict[name].append(x[0].detach().cpu())

        # Run AWQ search layer by layer
        for layer in tqdm(layers, desc="AWQ Search"):
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # Record the input features of every linear layer in this block.
            input_feat = defaultdict(list)
            handles = [
                module.register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
                for name, module in named_linears.items()
            ]
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # This block's output becomes the next block's input.
            inps = layer(inps, **layer_kwargs)[0]
            for handle in handles:
                handle.remove()
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # applying scales also updates input_feat accordingly
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )
                apply_scale(layer, scales_list, input_feat_dict=input_feat)
                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )
                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # NOTE (Haotian): check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        """Save the model config, then either the weights or the search result, plus ``quant_config.json``.

        If the model is quantized, or no search result exists, the state dict is
        written in shards (``model.safetensors`` when ``safetensors`` is true,
        otherwise ``pytorch_model.bin``). Otherwise only the AWQ search result is
        written, to ``awq_model_search_result.pt``.

        Args:
            save_dir: Output directory; a single trailing ``/`` is stripped.
            safetensors: Write weight shards with safetensors instead of ``torch.save``.
            shard_size: Maximum shard size passed to ``shard_checkpoint``.
        """
        def _save_files(save_dir, model_name='', search_result=None):
            class EmptyModule(nn.Module):
                def __init__(self):
                    super().__init__()

                def forward(self, x):
                    return x

            # Let save_pretrained write the config files; the empty state dict
            # keeps it from writing the real weights here.
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove the empty weights file that save_pretrained leaves behind.
            # Depending on the transformers serialization settings it may be
            # pytorch_model.bin or model.safetensors, so remove whichever exists
            # instead of assuming (and crashing on) pytorch_model.bin.
            for empty_weights in ('pytorch_model.bin', 'model.safetensors'):
                empty_path = os.path.join(save_dir, empty_weights)
                if os.path.isfile(empty_path):
                    os.remove(empty_path)

            if search_result is not None:
                torch.save(search_result, os.path.join(save_dir, model_name))
            else:
                # model_name has no extension; choose the weights file name from the format.
                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

                # shard checkpoint into chunks (10GB default)
                shards, index = shard_checkpoint(
                    self.model.state_dict(),
                    max_shard_size=shard_size,
                    weights_name=model_name
                )

                for shard_file, shard in shards.items():
                    shard_path = os.path.join(save_dir, shard_file)
                    if safetensors:
                        # safetensors must be in the same memory, so we duplicate and use contiguous memory
                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
                        save_file(shard, shard_path, metadata={"format": "pt"})
                    else:
                        torch.save(shard, shard_path)

                # index may be None (nothing to index); only write it when present.
                if index is not None:
                    with open(os.path.join(save_dir, f'{model_name}.index.json'), 'w') as file:
                        json.dump(index, file, indent=4)

            # Save config
            with open(os.path.join(save_dir, 'quant_config.json'), 'w') as file:
                json.dump(self.quant_config, file, indent=4)

        # Drop one trailing slash; endswith() also copes with an empty string.
        save_dir = save_dir[:-1] if save_dir.endswith('/') else save_dir

        # Save model
        if self.search_result is None or self.is_quantized:
            _save_files(save_dir, '', search_result=None)
        else:
            _save_files(save_dir, 'awq_model_search_result.pt', self.search_result)

    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False):
        """Load an unquantized model, wrapped so it can be passed to ``quantize``.

        Thin wrapper over ``from_quantized(..., is_quantized=False)``.
        """
        return cls.from_quantized(
            model_path,
            model_type,
            model_filename='',
            max_new_tokens=None,
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            is_quantized=False
        )

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename='',
                       max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
                       trust_remote_code=True, safetensors=False, is_quantized=True,
                       fuse_layers=False, version='GEMM'):
        """Build a wrapped model from a local directory or a Hub repo id.

        Args:
            model_path: Local directory, or a repo id fetched with ``snapshot_download``.
            model_type: Model type identifier stored on the wrapper.
            model_filename: Optional weights file name inside ``model_path``.
            max_new_tokens: Value for ``config.max_new_tokens``. When None it is
                read from ``config.<cls.max_new_tokens_key>`` if the class defines
                that attribute; otherwise it defaults to 2048.
            device: Currently unused; the device map always comes from
                ``infer_auto_device_map``.
            torch_dtype: dtype used to build the model and infer the device map.
            trust_remote_code: Forwarded to the ``transformers`` loaders.
            safetensors: Download filter choice, also forwarded as ``use_safetensors``.
            is_quantized: When True, swap in WQLinear layers and load the AWQ
                checkpoint; otherwise load plain weights with ``from_pretrained``.
            fuse_layers: Call ``cls.fuse_layers`` after loading a quantized model.
            version: Kernel version used when ``quant_config.json`` has none.

        Returns:
            An instance of ``cls`` wrapping the loaded model.
        """
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != '':
            model_weights_path = os.path.join(model_path, model_filename)
        else:
            model_weights_path = model_path

        # [STEP 2] Load quant config
        # TODO: Create BaseAWQConfig class
        quant_config_path = os.path.join(model_path, 'quant_config.json')
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.load(file)
            quant_config.setdefault("version", version)
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}

        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(cls, 'max_new_tokens_key'):
            config.max_new_tokens = getattr(config, cls.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens

        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            cls._load_quantized_modules(cls, model, quant_config, quant_config["version"])

        model.tie_weights()

        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[cls.layer_type],
            dtype=torch_dtype
        )

        # Load model weights
        if is_quantized:
            load_checkpoint_in_model(
                model,
                checkpoint=model_weights_path,
                device_map=device_map
            )

            model = simple_dispatch_model(model, device_map)

            if fuse_layers:
                cls.fuse_layers(model, quant_config)
        else:
            # If not quantized, must load with AutoModelForCausalLM
            del model

            model = AutoModelForCausalLM.from_pretrained(
                model_weights_path,
                device_map=device_map,
                trust_remote_code=trust_remote_code,
                offload_folder="offload",
                offload_state_dict=True,
                torch_dtype=torch_dtype,
                use_safetensors=safetensors
            )
            model.eval()

        return cls(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config, version):
        """Replace each block's ``nn.Linear`` layers in ``model`` with empty WQLinear modules.

        ``self`` may be the class itself: ``from_quantized`` calls this as
        ``cls._load_quantized_modules(cls, ...)``.

        Raises:
            AssertionError: if ``quant_config["zero_point"]`` is falsy.
            ValueError: if ``version`` is not ``'GEMM'`` or ``'GEMV'``.
        """
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        # Validate the version once instead of failing inside the loop.
        q_linear_module = self._select_qlinear_module(version)

        # Get blocks of model
        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):
            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        """Wrap the block's scalable activation in a ``ScaledActivation`` initialised with ones.

        This is a staticmethod that takes ``self`` explicitly, so it can be
        called as ``self._scale_activations(self, layer)`` with either an
        instance or the class. It does nothing if the activation is not
        scalable or is already a ``ScaledActivation``.
        """
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale (all ones, matching the block's dtype/device)
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)