import os
import gc
import json
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from safetensors.torch import save_file

from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from transformers.modeling_utils import shard_checkpoint
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model:PreTrainedModel = model
        self.model_type:str = model_type
        self.is_quantized:bool = is_quantized
        self.search_result = None
        self.quant_config:dict = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                       calib_data="pileval"):
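        # Two-stage flow: first search for and apply per-channel scales and clipping
        # ranges on calibration data, then pack the weights into low-bit WQLinear modules.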
        # Avoid mutating a shared default argument
        quant_config = {} if quant_config is None else quant_config
        self.quant_config = quant_config
        quant_config["version"] = quant_config.get("version", "GEMM")

        if run_search:
            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                       auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
        
        if run_quant:
            self._awq_quant()
            self.is_quantized = True
    
    @staticmethod
    def fuse_layers(model, quant_config):
        pass
        
    def _awq_quant(self):
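        # Quantize every nn.Linear in each block and swap it for a packed WQLinear module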
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data, 
                    get_scale_zp=True, 
                    w_bit=self.quant_config["w_bit"], 
                    q_group_size=self.quant_config["q_group_size"]
                )

                if self.quant_config["version"] == 'GEMM':
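                    # The GEMM kernel stores scales/zeros transposed (one row per input group)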
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear_module = WQLinear_GEMM
                elif self.quant_config["version"] == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                
                q_linear = q_linear_module.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,  # init_only=False: pack the quantized weights now
                    scales,
                    zeros
                )

                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
            
            torch.cuda.empty_cache()
            gc.collect()
    
    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, calib_data="pileval"):
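        # Search stage: run calibration data through the model block by block, recording
        # input activations and solving for scaling factors and clipping ranges.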
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")
        
        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        layers[0] = layers[0].module  # restore
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")
        
        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            def cache_input_hook(m, x, y, name, feat_dict):
                x = x[0]
                x = x.detach().cpu()
                feat_dict[name].append(x)

            input_feat = defaultdict(list)
            handles = []
            for name in named_linears:
                handles.append(named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name,
                                    feat_dict=input_feat)))
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()
            
            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()
        
        return awq_results

    def save_quantized(self, save_dir, use_safetensors=False, shard_size="10GB"):
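        # Saves either the quantized model weights (sharded) or, if only the search was run,
        # the raw search result containing the scales and clipping values.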
        def _save_files(save_dir, model_name, search_result=None):
            class EmptyModule(nn.Module):
                def __init__(self): super(EmptyModule, self).__init__()
                def forward(self, x): return x

            # Save model files with empty state dict
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove empty state dict
            os.remove(f'{save_dir}/pytorch_model.bin')

            if search_result is not None:
                torch.save(search_result, f'{save_dir}/{model_name}')
            else:
                # model_name has no extension, add it when saving state_dict
                model_name += '.safetensors' if use_safetensors else '.bin'

                # shard checkpoint into chunks (10GB default)
                shards, index = shard_checkpoint(
                    self.model.state_dict(), 
                    max_shard_size=shard_size, 
                    weights_name=model_name
                )

                for shard_file, shard in shards.items():
                    if use_safetensors:
                        # safetensors requires tensors to own contiguous, non-shared memory, so clone before saving
                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
                        save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
                    else:
                        torch.save(shard, os.path.join(save_dir, shard_file))

                # save shard index
                if index is not None:
                    with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                        file.write(json.dumps(index, indent=4))

            # Save config
            with open(f'{save_dir}/quant_config.json', 'w+') as file:
                file.write(json.dumps(self.quant_config, indent=4))

        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Save model
        if self.search_result is None or self.is_quantized:
            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}'
            _save_files(save_dir, model_name, search_result=None)
        else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True, safetensors=False):
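        # Unquantized load path: reuse the from_quantized plumbing with is_quantized=False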
        return self.from_quantized(
            model_path, 
            model_type, 
            model_filename='', 
            max_new_tokens=None,
            device='balanced', 
            torch_dtype=torch_dtype, 
            trust_remote_code=trust_remote_code, 
            safetensors=safetensors,
            is_quantized=False
        )

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, 
                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())
            
            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
        
        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = max_new_tokens
        
        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Only need to replace layers if the model is AWQ-quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, quant_config, quant_config["version"])
        
        model.tie_weights()

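        # Let accelerate place whole decoder blocks on devices without splitting them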
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type], 
            dtype=torch_dtype
        )

        # Load model weights
        if is_quantized:
            model = load_checkpoint_and_dispatch(
                model, 
                model_filename, 
                device_map=device_map, 
                no_split_module_classes=[self.layer_type]
            )

            if fuse_layers:
                self.fuse_layers(model, quant_config)

        else:
            # If not quantized, must load with AutoModelForCausalLM
            del model
            
            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_filename, 
                device_map=device_map, 
                trust_remote_code=trust_remote_code, 
                offload_folder="offload", 
                offload_state_dict=True, 
                torch_dtype=torch_dtype, 
                use_safetensors=safetensors
            )
            model.eval()

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config, version):
        # Replace nn.Linear with empty WQLinear modules; the quantized weights are loaded afterwards
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True  # init_only: allocate quantized buffers without packing weights
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
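        # Insert a ScaledActivation wrapper (initialized to ones) where the model defines a scalable activation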
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)