base.py 15.2 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import gc
3
import json
Casper Hansen's avatar
Casper Hansen committed
4
import torch
Casper Hansen's avatar
Casper Hansen committed
5
import logging
Casper Hansen's avatar
Casper Hansen committed
6
7
import functools
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
8
from tqdm import tqdm
9
from typing import List, Union
Casper Hansen's avatar
Casper Hansen committed
10
11
from collections import defaultdict

12
from awq.modules.act import ScaledActivation
13
from huggingface_hub import snapshot_download
Casper Hansen's avatar
Casper Hansen committed
14
from awq.utils.calib_data import get_calib_dataset
15
from awq.quantize.quantizer import pseudo_quantize_tensor
16
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
Casper Hansen's avatar
Casper Hansen committed
17
18
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
19
20
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
Casper Hansen's avatar
Casper Hansen committed
21
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
Casper Hansen's avatar
Casper Hansen committed
22

23
class BaseAWQForCausalLM(nn.Module):
24
    def __init__(self, model, model_type, is_quantized, quant_config):
25
        super().__init__()
26
27
28
29
        self.model:PreTrainedModel = model
        self.model_type:str = model_type
        self.is_quantized:bool = is_quantized
        self.search_result = None
30
        self.quant_config:dict = quant_config
31
32
33
34
35
36
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
Casper Hansen's avatar
Casper Hansen committed
37
38
39
40
    
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)
41

Casper Hansen's avatar
Casper Hansen committed
42
    @torch.no_grad()
43
    def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
44
                       auto_scale=True, mse_range=True, run_search=True, run_quant=True,
45
46
                       calib_data: Union[str, List[str]]="pileval", split="train",
                       text_column="text"):
47
        self.quant_config = quant_config
48
        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
49

Casper Hansen's avatar
Casper Hansen committed
50
        if run_search:
51
52
53
54
55
            self.search_result = self._awq_search(
                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
                split=split, text_column=text_column
            )
Casper Hansen's avatar
Casper Hansen committed
56
57
        
        if run_quant:
58
            self._awq_quant()
Casper Hansen's avatar
Casper Hansen committed
59
            self.is_quantized = True
Casper Hansen's avatar
Casper Hansen committed
60
    
qwopqwop200's avatar
qwopqwop200 committed
61
    @staticmethod
62
    def fuse_layers(model, quant_config):
qwopqwop200's avatar
qwopqwop200 committed
63
64
        pass
        
65
66
    def _awq_quant(self):
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
67
        layers = self.get_model_layers(self.model)
Casper's avatar
Casper committed
68

Casper Hansen's avatar
Casper Hansen committed
69
70
71
72
        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)
73
            self._scale_activations(self, layer)
Casper Hansen's avatar
Casper Hansen committed
74
75

            for name, module in named_linears.items():
Casper Hansen's avatar
Casper Hansen committed
76
                module.cuda()
77
78
79
80

                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data, 
                    get_scale_zp=True, 
81
82
                    w_bit=self.quant_config["w_bit"], 
                    q_group_size=self.quant_config["q_group_size"]
83
84
                )

85
                if self.quant_config["version"] == 'GEMM':
86
87
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
88
89
90
91
92
93
94
95
96
97
                    q_linear_module = WQLinear_GEMM
                elif self.quant_config["version"] == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                
                q_linear = q_linear_module.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,
                    scales,
98
99
100
                    zeros
                )

Casper Hansen's avatar
Casper Hansen committed
101
102
103
104
105
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
Casper Hansen's avatar
Casper Hansen committed
106
107
108
109
            
            torch.cuda.empty_cache()
            gc.collect()
    
110
    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval",
                       split="train", text_column="text"):
        """Search AWQ scales and clipping values layer by layer.

        Calibration samples are pushed through the model far enough to capture
        the input (and keyword arguments) of layer 0. Each layer is then run
        in turn on the previous layer's output while forward hooks record the
        inputs of its linear modules, which feed ``auto_scale_block`` and
        ``auto_clip_block``.

        Args:
            tokenizer: Tokenizer passed to ``get_calib_dataset``.
            quant_config: Quantization settings passed to the scale/clip search.
            n_samples: Number of calibration samples.
            seqlen: Passed to ``get_calib_dataset`` as ``block_size``.
            auto_scale: When true, compute and apply scales.
            mse_range: When true, compute and apply clipping.
            calib_data: Calibration data spec passed to ``get_calib_dataset``.
            split: Passed to ``get_calib_dataset`` as ``split``.
            text_column: Passed to ``get_calib_dataset`` as ``text_column``.

        Returns:
            dict with keys ``"scale"`` and ``"clip"``, each a list of results
            whose names are prefixed with the layer's name inside the model.

        Raises:
            RuntimeError: If the forward pass never reached layer 0, so no
                input could be captured.
        """
        layers = self.get_model_layers(self.model)

        # NOTE(review): assumes get_calib_dataset accepts ``split`` and
        # ``text_column`` keyword arguments -- confirm against its definition.
        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
            split=split, text_column=text_column)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        finally:
            # Always put the real layer back, even if the forward pass failed
            # with something other than the early-exit ValueError.
            layers[0] = layers[0].module
        del samples

        if not inps:
            raise RuntimeError(
                "AWQ search could not capture the input of the first layer; "
                "the model forward pass ended before reaching it."
            )
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Forward hook: stash the first positional input of a linear module on CPU.
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            input_feat = defaultdict(list)
            handles = []
            for name in named_linears:
                handles.append(named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name,
                                    feat_dict=input_feat)))
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results
Casper's avatar
Casper committed
221

222
    def save_quantized(self, save_dir):
        """Write the model files, the weights or search result, and the quant config.

        When a search result exists and the model is not yet quantized, the
        search result is saved as ``awq_model_search_result.pt``. Otherwise the
        model state dict is saved as ``awq_model_w{w_bit}_g{q_group_size}.pt``.
        ``quant_config.json`` is written in both cases.

        Args:
            save_dir: Target directory path; one trailing ``/`` is stripped.
        """
        class _Stateless(nn.Module):
            """Parameter-free module used only for its empty state dict."""

            def forward(self, x):
                return x

        def _write_artifacts(target_dir, weights_name, payload):
            # Save model files without search results (empty state dict)
            self.model.save_pretrained(target_dir, state_dict=_Stateless().state_dict())

            # Remove the placeholder weights file written by the call above
            os.remove(f'{target_dir}/pytorch_model.bin')

            # Save the real payload (state dict or search results)
            torch.save(payload, f'{target_dir}/{weights_name}')

            # Save config
            serialized_config = json.dumps(self.quant_config, indent=4)
            with open(f'{target_dir}/quant_config.json', 'w+') as file:
                file.write(serialized_config)

        if save_dir[-1] == '/':
            save_dir = save_dir[:-1]

        # Save model
        only_searched = self.search_result is not None and not self.is_quantized
        if only_searched:
            _write_artifacts(save_dir, 'awq_model_search_result.pt', self.search_result)
        else:
            weights_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
            _write_artifacts(save_dir, weights_name, self.model.state_dict())
        
251
252
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
Casper Hansen's avatar
Casper Hansen committed
253
                        trust_remote_code=True, safetensors=False):
254
255
256
        return self.from_quantized(
            model_path, 
            model_type, 
257
            model_filename='', 
258
            max_new_tokens=None,
259
260
261
            device='balanced', 
            torch_dtype=torch_dtype, 
            trust_remote_code=trust_remote_code, 
Casper Hansen's avatar
Casper Hansen committed
262
            safetensors=safetensors,
263
264
            is_quantized=False
        )
Casper's avatar
Casper committed
265

266
    @classmethod
    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, 
                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
        """Build a wrapper instance from a local directory or a hub repository.

        ``self`` is the class here (this is a classmethod).

        Args:
            model_path: Local directory; anything that is not a directory is
                passed to ``snapshot_download``.
            model_type: Passed through to the constructor.
            model_filename: Name appended to ``model_path`` to form the path
                weights are loaded from.
            max_new_tokens: Value stored on the model config. When ``None``,
                it is read from the config via ``max_new_tokens_key`` if the
                class defines that attribute, else 2048 is used.
            device: Accepted but not referenced in this method.
            torch_dtype: Dtype used for model creation and device-map inference.
            trust_remote_code: Forwarded to the transformers loaders.
            safetensors: Selects which weight files the download skips, and is
                forwarded as ``use_safetensors`` on the non-quantized path.
            is_quantized: Whether to swap in quantized linear modules and load
                a quantized checkpoint.
            fuse_layers: When true (quantized path only), call ``fuse_layers``.
            version: Fallback for the ``"version"`` entry of the quant config.

        Returns:
            An instance of this class wrapping the loaded model.
        """
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            skipped = ["*msgpack*", "*h5*"]
            # Skip whichever weight format was not requested.
            skipped += ["*.pt", "*.bin"] if safetensors else ["*safetensors*"]
            model_path = snapshot_download(model_path, ignore_patterns=skipped)

        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if not os.path.exists(quant_config_path):
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
        else:
            with open(quant_config_path, 'r') as file:
                quant_config = json.load(file)
            quant_config.setdefault("version", version)

        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens

        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, quant_config, quant_config["version"])

        model.tie_weights()

        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type],
            dtype=torch_dtype,
        )

        # Load model weights
        if not is_quantized:
            # If not quantized, must load with AutoModelForCausalLM
            del model

            model = AutoModelForCausalLM.from_pretrained(
                model_filename,
                device_map=device_map,
                trust_remote_code=trust_remote_code,
                offload_folder="offload",
                offload_state_dict=True,
                torch_dtype=torch_dtype,
                use_safetensors=safetensors,
            )
            model.eval()
        else:
            model = load_checkpoint_and_dispatch(
                model,
                model_filename,
                device_map=device_map,
                no_split_module_classes=[self.layer_type],
            )

            if fuse_layers:
                self.fuse_layers(model, quant_config)

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
Casper's avatar
Casper committed
351

Casper Hansen's avatar
Casper Hansen committed
352
    def _load_quantized_modules(self, model, quant_config, version):
353
        # Real quantization of weights
354
        assert quant_config["zero_point"], "We only support zero_point quantization now."
355
356
        
        # Get blocks of model
357
        layers = self.get_model_layers(model)
358

359
360
        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]
361
362

            # Get every linear layer in a block
363
            named_linears = get_named_linears(layer)
364
365

            # Replace activation functions
366
            self._scale_activations(self, layer)
367

368
            # Replace nn.Linear with WQLinear
369
            for name, module in named_linears.items():
Casper Hansen's avatar
Casper Hansen committed
370
371
372
373
374
375
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV
                
                q_linear = q_linear_module.from_linear(
376
377
378
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
Casper Hansen's avatar
Casper Hansen committed
379
380
                    True
                )
381
382
383
384
385
386
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
387
    @staticmethod
388
    def _scale_activations(self, layer):
389
        scale_dict = self.get_act_for_scaling(layer)
390

391
392
393
        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())
394

395
396
                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
397

398
399
                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
Casper's avatar
Casper committed
400
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)