import os
import gc
import json
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from safetensors.torch import save_file

from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.utils import simple_dispatch_model
from awq.utils.calib_data import get_calib_dataset
from transformers.modeling_utils import shard_checkpoint
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

class BaseAWQForCausalLM(nn.Module):
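    """
    Base wrapper around a Hugging Face causal LM that exposes the AWQ workflow:
    search (`_awq_search`), quantization (`_awq_quant`), saving (`save_quantized`)
    and loading (`from_pretrained` / `from_quantized`).
    """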
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: dict = quant_config

    def to(self, device: str):
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

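    # Runs the AWQ scale/clip search on calibration data and, optionally, quantizes the model in place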
    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                       calib_data="pileval"):
        self.quant_config = quant_config
        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

        if run_search:
            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                       auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant()
            self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
        # No-op in the base class; model-specific subclasses can override this to fuse modules after loading
        pass

    def _awq_quant(self):
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)
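            # Fake-quantize each linear's weights to obtain per-group scales/zeros,
            # then replace it with a packed WQLinear kernel module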
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    w_bit=self.quant_config["w_bit"],
                    q_group_size=self.quant_config["q_group_size"]
                )

                if self.quant_config["version"] == 'GEMM':
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear_module = WQLinear_GEMM
                elif self.quant_config["version"] == 'GEMV':
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,
                    scales,
                    zeros
                )

                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, calib_data="pileval"):
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)
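        # Tokenized calibration blocks are concatenated into a single batch of input ids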

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")
        
        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        layers[0] = layers[0].module  # restore
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            def cache_input_hook(m, x, y, name, feat_dict):
                x = x[0]
                x = x.detach().cpu()
                feat_dict[name].append(x)

            input_feat = defaultdict(list)
            handles = []
            for name in named_linears:
                handles.append(named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name,
                                    feat_dict=input_feat)))
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()
            
            if mse_range:
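                # Search per-channel weight clipping thresholds that minimize error against the cached inputs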
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()
        return awq_results

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        def _save_files(save_dir, model_name='', search_result=None):
            class EmptyModule(nn.Module):
                def __init__(self): super(EmptyModule, self).__init__()
                def forward(self, x): return x

            # Save model files with empty state dict
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove empty state dict
            os.remove(f'{save_dir}/pytorch_model.bin')

            if search_result is not None:
                torch.save(search_result, f'{save_dir}/{model_name}')
            else:
                # model_name has no extension, add it when saving state_dict
                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

                # shard checkpoint into chunks (10GB default)
                shards, index = shard_checkpoint(
                    self.model.state_dict(), 
                    max_shard_size=shard_size, 
                    weights_name=model_name
                )

                for shard_file, shard in shards.items():
                    if safetensors:
                        # safetensors must be in the same memory, so we duplicate and use contiguous memory
                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
                        save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
                    else:
                        torch.save(shard, os.path.join(save_dir, shard_file))

                # save shard index
                if index is not None:
                    with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                        file.write(json.dumps(index, indent=4))

            # Save config
            with open(f'{save_dir}/quant_config.json', 'w+') as file:
                file.write(json.dumps(self.quant_config, indent=4))

        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Save model
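        # If the weights are already quantized (or no search result exists), save the full state dict;
        # otherwise only the intermediate AWQ search result is written to disk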
        if self.search_result is None or self.is_quantized:
            _save_files(save_dir, '', search_result=None)
        else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)

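    # Convenience loader for an unquantized model; delegates to from_quantized with is_quantized=False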
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True, safetensors=False):
        return self.from_quantized(
            model_path,
            model_type,
            model_filename='',
            max_new_tokens=None,
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            is_quantized=False
        )

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='pytorch_model.bin', 
                             max_new_tokens=None, device='balanced', torch_dtype=torch.float16, 
                             trust_remote_code=True, safetensors=False, is_quantized=True, 
                             fuse_layers=False, version='GEMM'):
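        # Loads either an AWQ-quantized checkpoint (is_quantized=True) or a plain model to be quantized later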
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        model_weights_path = model_path + f'/{model_filename}'

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())

            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
        
        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = max_new_tokens
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, quant_config, quant_config["version"])

        model.tie_weights()

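        # Infer a device map that never splits a decoder block (self.layer_type) across devices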
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type], 
            dtype=torch_dtype
        )

        # Load model weights
        if is_quantized:
            load_checkpoint_in_model(
                model,
                checkpoint=model_path if safetensors else model_weights_path,
                device_map=device_map
            )
            
            model = simple_dispatch_model(model, device_map)
            
            if fuse_layers:
                self.fuse_layers(model, quant_config)

        else:
            # If not quantized, must load with AutoModelForCausalLM
            del model
            
            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_weights_path,
                device_map=device_map, 
                trust_remote_code=trust_remote_code, 
                offload_folder="offload", 
                offload_state_dict=True, 
                torch_dtype=torch_dtype, 
                use_safetensors=safetensors
            )
            model.eval()

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
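                # The `True` flag builds an empty packed module; its quantized weights are filled in
                # when the checkpoint is loaded afterwards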
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)