# base.py
import os
import gc
import json
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from safetensors.torch import save_file

from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.utils import simple_dispatch_model
from awq.utils.calib_data import get_calib_dataset
from transformers.modeling_utils import shard_checkpoint
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

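# Base class shared by the model-specific AWQ wrappers. Subclasses are expected to provide
# get_model_layers(), get_act_for_scaling(), move_embed(), a layer_type string and, optionally,
# a max_new_tokens_key -- everything else (search, quantization, saving, loading) lives here.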
class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: dict = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

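    # quantize() runs in two phases: an AWQ search over calibration data (_awq_search, producing
    # per-layer scales and clipping values) followed by the actual in-place weight quantization
    # (_awq_quant). run_search / run_quant allow the two phases to be executed independently.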
    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                       calib_data="pileval"):
        # avoid a shared mutable default: quant_config is modified below
        quant_config = {} if quant_config is None else quant_config
        self.quant_config = quant_config
        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

        if run_search:
            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                                                  auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant()
            self.is_quantized = True
    
    @staticmethod
    def fuse_layers(model, quant_config):
        pass

    def _awq_quant(self):
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

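        # For each transformer block: scale activations where the model needs it, pseudo-quantize
        # every nn.Linear to recover per-group scales/zeros, then pack the weights into a
        # WQLinear_GEMM / WQLinear_GEMV module placed on the same device as the block.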
        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                # quantize the weights in place and keep the per-group scales / zero points
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    w_bit=self.quant_config["w_bit"],
                    q_group_size=self.quant_config["q_group_size"]
                )

                if self.quant_config["version"] == 'GEMM':
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear_module = WQLinear_GEMM
                elif self.quant_config["version"] == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unsupported version: {self.quant_config['version']}")

                q_linear = q_linear_module.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,
                    scales,
                    zeros
                )

                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()
    
    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, calib_data="pileval"):
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")
        
        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        layers[0] = layers[0].module  # restore
        inps = inps[0]
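        # inps now holds the calibration activations entering block 0; layer_kwargs holds the
        # extra forward() arguments captured from the hijacked call (e.g. the attention mask).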

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")
        
        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            def cache_input_hook(m, x, y, name, feat_dict):
                x = x[0]
                x = x.detach().cpu()
                feat_dict[name].append(x)

            input_feat = defaultdict(list)
            handles = []
            for name in named_linears:
                handles.append(named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name,
                                    feat_dict=input_feat)))
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

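            # auto_scale_block searches per-channel scaling factors that protect the most
            # activation-salient weight channels before quantization (the core idea behind AWQ).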
            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()
            
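            # auto_clip_block searches per-channel clipping thresholds that minimize the MSE
            # introduced by quantizing the weights, using the cached input features.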
            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()
        
        return awq_results

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        def _save_files(save_dir, model_name='', search_result=None):
            class EmptyModule(nn.Module):
                def __init__(self): super(EmptyModule, self).__init__()
                def forward(self, x): return x

            # Save model files with empty state dict
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove empty state dict
            os.remove(f'{save_dir}/pytorch_model.bin')

            if search_result is not None:
                torch.save(search_result, f'{save_dir}/{model_name}')
            else:
                # model_name has no extension, add it when saving state_dict
                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

                # shard checkpoint into chunks (10GB default)
                shards, index = shard_checkpoint(
                    self.model.state_dict(),
                    max_shard_size=shard_size,
                    weights_name=model_name
                )

                for shard_file, shard in shards.items():
                    if safetensors:
                        # safetensors must be in the same memory, so we duplicate and use contiguous memory
                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
                        save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
                    else:
                        torch.save(shard, os.path.join(save_dir, shard_file))

                # save shard index
                if index is not None:
                    with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                        file.write(json.dumps(index, indent=4))

            # Save config
            with open(f'{save_dir}/quant_config.json', 'w+') as file:
                file.write(json.dumps(self.quant_config, indent=4))

        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Save model
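        # Two cases: if the model has been quantized (or no search was run), save the real
        # weights; otherwise only the AWQ search result (scales and clip values) is saved.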
        if self.search_result is None or self.is_quantized:
            _save_files(save_dir, '', search_result=None)
        else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False):
        return self.from_quantized(
            model_path,
            model_type,
            model_filename='',
            max_new_tokens=None,
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            is_quantized=False
        )

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='',
                             max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
                             trust_remote_code=True, safetensors=False, is_quantized=True,
                             fuse_layers=False, version='GEMM'):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())

            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = max_new_tokens
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, quant_config, quant_config["version"])

        model.tie_weights()

        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type], 
            dtype=torch_dtype
        )

        # Load model weights
        if is_quantized:
            load_checkpoint_in_model(
                model,
                checkpoint=model_weights_path,
                device_map=device_map
            )

            model = simple_dispatch_model(model, device_map)

            if fuse_layers:
                self.fuse_layers(model, quant_config)

        else:
            # If not quantized, must load with AutoModelForCausalLM
            del model

            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_weights_path,
                device_map=device_map,
                trust_remote_code=trust_remote_code,
                offload_folder="offload",
                offload_state_dict=True,
                torch_dtype=torch_dtype,
                use_safetensors=safetensors
            )
            model.eval()

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                else:
                    raise ValueError(f"Unsupported version: {version}")

                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
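        # get_act_for_scaling is provided by the subclass and returns a dict with
        # 'is_scalable', 'scale_name', 'scale_layer' and 'scale_shape' for the block.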
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
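
# Typical end-to-end usage (sketch; `SomeAWQForCausalLM`, the model type string and the paths
# are hypothetical -- in practice a concrete subclass of BaseAWQForCausalLM is used):
#
#   model = SomeAWQForCausalLM.from_pretrained("org/fp16-model", "some_model_type")
#   model.quantize(tokenizer, quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4})
#   model.save_quantized("fp16-model-awq")
#
#   model = SomeAWQForCausalLM.from_quantized("fp16-model-awq", "some_model_type")
#   output = model.generate(input_ids.cuda(), max_new_tokens=128)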