import functools
import gc
import json
import logging
import os
from collections import defaultdict
from typing import Optional

import torch
import torch.nn as nn
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from huggingface_hub import snapshot_download
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel

from awq.modules.act import ScaledActivation
from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.utils.calib_data import get_calib_dataset
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name


class BaseAWQForCausalLM(nn.Module):
    """Wrap a Hugging Face causal LM with AWQ search, quantization, save and load helpers.

    Several model-specific members are used here but not defined in this class;
    presumably each model subclass provides them:

    * ``get_model_layers(model)``: indexable (and item-assignable) container of the
      model's blocks.
    * ``get_act_for_scaling(layer)``: dict with keys ``is_scalable``, ``scale_layer``,
      ``scale_name`` and ``scale_shape`` describing an activation to wrap.
    * ``move_embed(model, device)``: moves the embedding modules to ``device``.
    * ``layer_type``: passed to ``accelerate`` as ``no_split_module_classes``.
    * ``max_new_tokens_key`` (optional): config attribute read as ``max_new_tokens``.
    """

    def __init__(self, model, model_type, is_quantized, quant_config):
        """Store the wrapped model and its quantization state.

        Args:
            model: The wrapped ``PreTrainedModel``.
            model_type: Model type string, stored as ``self.model_type``.
            is_quantized: Whether ``model``'s linear layers are already WQLinear modules.
            quant_config: Quantization settings dict (``w_bit``, ``q_group_size``,
                ``zero_point`` and optionally ``version``).
        """
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        # Set by quantize(run_search=True); save_quantized() writes it out when
        # the model has not been quantized yet.
        self.search_result = None
        self.quant_config: dict = quant_config

    def to(self, device: str):
        """Move the wrapped model to ``device``; returns the wrapped model, not this wrapper."""
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Call the wrapped model's ``generate`` under ``torch.inference_mode``."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config: Optional[dict] = None, n_samples=128,
                 seqlen=512, auto_scale=True, mse_range=True, run_search=True,
                 run_quant=True, calib_data="pileval"):
        """Run the AWQ search and/or replace linear layers with packed WQLinear modules.

        Args:
            tokenizer: Tokenizer forwarded to ``get_calib_dataset``.
            quant_config: Quantization settings. When omitted, the config this model
                was created with (``self.quant_config``) is reused. The previous
                signature used a shared mutable ``{}`` default, which then crashed on
                the missing ``zero_point`` key.
            n_samples: Forwarded to ``get_calib_dataset`` as ``n_samples``.
            seqlen: Forwarded to ``get_calib_dataset`` as ``block_size``.
            auto_scale: Search and apply activation-aware scales.
            mse_range: Search and apply weight clipping ranges.
            run_search: Run the search and store its result in ``self.search_result``.
            run_quant: Pack weights into WQLinear modules and set ``self.is_quantized``.
            calib_data: Calibration data spec forwarded to ``get_calib_dataset``.
        """
        if quant_config is None:
            quant_config = self.quant_config or {}
        self.quant_config = quant_config

        if run_search:
            self.search_result = self._awq_search(
                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant()
            self.is_quantized = True

    @staticmethod
    def fuse_layers(model):
        """Hook called by ``from_quantized(fuse_layers=True)``; a no-op unless overridden."""
        pass

    def _awq_quant(self):
        """Replace every linear layer of every block with a packed ``WQLinear_GEMM``.

        Each weight is pseudo-quantized with ``self.quant_config`` to obtain scales and
        zero points, and the ``nn.Linear`` is then swapped out in place.
        """
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        # "version" only selects the WQLinear kernel layout; it is not a quantization
        # parameter, so don't forward it as a keyword to pseudo_quantize_tensor (the
        # default config built in from_quantized contains it).
        quant_kwargs = {k: v for k, v in self.quant_config.items() if k != "version"}
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for layer in tqdm(layers, desc="AWQ Quantization"):
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    **quant_kwargs
                )

                # NOTE(review): packing always uses the GEMM layout, even when
                # quant_config["version"] == "GEMV"; confirm this is intended.
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()

                q_linear = WQLinear_GEMM.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,
                    scales,
                    zeros
                )

                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data="pileval"):
        """Run the AWQ scale/clip search block by block.

        A calibration batch is pushed through the model only as far as block 0, to
        capture that block's inputs and keyword arguments. The search then walks the
        blocks in order, one at a time on CUDA:

        * record each linear layer's inputs with forward hooks;
        * search and apply scales (``auto_scale``) and clipping (``mse_range``);
        * feed the block's output to the next block.

        Returns:
            dict with lists ``"scale"`` and ``"clip"``; each entry is prefixed with
            the block's module name via ``append_str_prefix``.
        """
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        class _StopForward(Exception):
            """Private early-exit signal; unlike ValueError it cannot come from the model."""

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise _StopForward  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except _StopForward:  # work with early exit
            pass
        finally:
            layers[0] = layers[0].module  # restore, even if the forward pass failed
        del samples
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        def cache_input_hook(m, x, y, name, feat_dict):
            # forward hooks receive the positional inputs as a tuple; keep the first
            feat_dict[name].append(x[0].detach().cpu())

        # Run AWQ search layer by layer
        for layer in tqdm(layers, desc="AWQ Search"):
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            input_feat = defaultdict(list)
            handles = [
                module.register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
                for name, module in named_linears.items()
            ]
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layer, scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized(self, save_dir):
        """Write the model config, weights (or search results) and ``quant_config.json``.

        If the model is quantized, or no search result exists, the full state dict is
        saved as ``awq_model_w{w_bit}_g{q_group_size}.pt``. Otherwise only the search
        result is saved, as ``awq_model_search_result.pt``.
        """
        def _save_files(save_dir, model_name, model):
            class EmptyModule(nn.Module):
                def __init__(self): super(EmptyModule, self).__init__()
                def forward(self, x): return x

            # Save model files (config etc.) with an empty state dict so the real
            # weights are not written by save_pretrained.
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove the empty weights file written above. Depending on the
            # transformers version/settings it is a .bin or a .safetensors file.
            for weights_name in ('pytorch_model.bin', 'model.safetensors'):
                weights_path = f'{save_dir}/{weights_name}'
                if os.path.exists(weights_path):
                    os.remove(weights_path)

            # Save search results
            torch.save(model, f'{save_dir}/{model_name}')

            # Save config
            with open(f'{save_dir}/quant_config.json', 'w+') as file:
                json.dump(self.quant_config, file, indent=4)

        # strip trailing slashes (also safe for an empty string)
        save_dir = save_dir.rstrip('/')

        # Save model
        if self.search_result is None or self.is_quantized:
            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
            _save_files(save_dir, model_name, self.model.state_dict())
        else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)

    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False):
        """Load an unquantized model, e.g. to run ``quantize`` on it."""
        return cls.from_quantized(
            model_path,
            model_type,
            model_filename='',
            max_new_tokens=None,
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            is_quantized=False
        )

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename, max_new_tokens=None,
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
        """Load a (quantized or plain) model from ``model_path`` and wrap it in ``cls``.

        Args:
            model_path: Local directory, or a Hub repo id that is downloaded when it is
                not an existing directory.
            model_type: Stored on the returned wrapper.
            model_filename: Weights file inside ``model_path`` (``''`` for unquantized
                loads, which then load from the directory itself).
            max_new_tokens: Stored as ``config.max_new_tokens``. When ``None``, it is
                read from ``config.<cls.max_new_tokens_key>`` if the class defines
                that key, otherwise it defaults to 2048.
            device: NOTE(review): currently unused; placement comes from
                ``infer_auto_device_map``.
            torch_dtype: dtype used to build and load the model.
            trust_remote_code: Forwarded to ``transformers``.
            safetensors: Selects which weight files are downloaded; for unquantized
                loads it is also passed as ``use_safetensors``.
            is_quantized: Replace linear layers with WQLinear modules before loading.
            fuse_layers: Call ``cls.fuse_layers`` after loading quantized weights.
            version: WQLinear kernel layout, ``'GEMM'`` or ``'GEMV'``.

        Returns:
            An instance of ``cls`` wrapping the loaded model.
        """
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.load(file)
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(cls, 'max_new_tokens_key'):
            config.max_new_tokens = getattr(config, cls.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens

        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            cls._load_quantized_modules(cls, model, quant_config, version)

        model.tie_weights()

        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[cls.layer_type],
            dtype=torch_dtype
        )

        # Load model weights
        if is_quantized:
            model = load_checkpoint_and_dispatch(
                model,
                model_filename,
                device_map=device_map,
                no_split_module_classes=[cls.layer_type]
            )

            if fuse_layers:
                cls.fuse_layers(model)

        else:
            # If not quantized, must load with AutoModelForCausalLM
            del model

            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_filename,
                device_map=device_map,
                trust_remote_code=trust_remote_code,
                offload_folder="offload",
                offload_state_dict=True,
                torch_dtype=torch_dtype,
                use_safetensors=safetensors
            )
            model.eval()

        return cls(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config, version):
        """Replace every ``nn.Linear`` in each block of ``model`` with a WQLinear module.

        ``from_quantized`` calls this unbound, passing the class as ``self``. The WQLinear
        modules are created with ``from_linear(..., True)`` and no scales/zeros. The
        real weights are loaded afterwards by ``load_checkpoint_and_dispatch`` in
        ``from_quantized``.

        Raises:
            ValueError: if ``version`` is neither ``'GEMM'`` nor ``'GEMV'``.
        """
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        if version == 'GEMM':
            logging.warning('Deprecated model weight format. Re-quantize '
                            'your weights again with version="GEMV" for a speedup. '
                            'In the next AutoAWQ version, GEMM will be deprecated.')

        # Validate up front: an unknown version used to surface as an
        # UnboundLocalError from inside the replacement loop.
        if version not in ('GEMM', 'GEMV'):
            raise ValueError(f"Unsupported version {version!r}; expected 'GEMM' or 'GEMV'.")
        q_linear_module = WQLinear_GEMM if version == 'GEMM' else WQLinear_GEMV

        # Get blocks of model
        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):
            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        """Wrap the block's scalable activation, if any, in ``ScaledActivation``.

        Declared static but takes the wrapper (or class) explicitly as ``self``, so both
        instance and class callers can use it. The wrapper is initialised with an
        all-ones tensor of ``scale_shape``, presumably a neutral scale until real
        values are applied or loaded. An activation that is already a
        ``ScaledActivation`` is left untouched.
        """
        scale_dict = self.get_act_for_scaling(layer)

        if not scale_dict['is_scalable']:
            return
        if isinstance(scale_dict['scale_layer'], ScaledActivation):
            return  # already wrapped; don't wrap twice

        param = next(layer.parameters())

        # get activation scale
        scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

        # scale activation
        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)