base.py 13.9 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import gc
3
import json
4
from typing import List, Union
Casper Hansen's avatar
Casper Hansen committed
5
6
7
import torch
import functools
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
8
from tqdm import tqdm
Casper Hansen's avatar
Casper Hansen committed
9
10
from collections import defaultdict

11
from huggingface_hub import snapshot_download
Casper Hansen's avatar
Casper Hansen committed
12
from awq.utils.calib_data import get_calib_dataset
13
14
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
15
16
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
17
18
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
Casper Hansen's avatar
Casper Hansen committed
19
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
Casper Hansen's avatar
Casper Hansen committed
20

21
class BaseAWQForCausalLM(nn.Module):
    """Wrap a Hugging Face causal LM with AWQ search, quantization, save and load helpers.

    This base class calls, but does not define, several model-specific hooks that
    subclasses must provide: ``get_model_layers``, ``get_act_for_scaling``,
    ``move_embed``, ``layer_type`` and (optionally) ``max_new_tokens_key``.
    """

    # Default config that works for most models; used when neither the caller
    # nor a checkpoint's quant_config.json supplies one.
    _DEFAULT_QUANT_CONFIG = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        # Filled in by quantize(run_search=True); None until a search has run.
        self.search_result = None
        self.quant_config: dict = quant_config

    def to(self, device: str):
        """Move the wrapped model to ``device``.

        NOTE: unlike ``nn.Module.to``, this returns the wrapped model
        (``self.model.to(device)``), not this wrapper.
        """
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped model."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Call the wrapped model's ``generate`` under ``torch.inference_mode``."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                 calib_data: Union[str, List[str]] = "pileval"):
        """Run the AWQ search and/or apply AWQ weight quantization.

        Args:
            tokenizer: passed to ``get_calib_dataset`` for the calibration data.
            quant_config: quantization settings (the code reads ``zero_point``,
                ``w_bit`` and ``q_group_size``). When omitted, the config already
                stored on this wrapper is reused; if none is stored,
                ``_DEFAULT_QUANT_CONFIG`` is used.
            n_samples, seqlen, auto_scale, mse_range, calib_data: forwarded to
                ``_awq_search``.
            run_search: when True, stores the search result in ``self.search_result``.
            run_quant: when True, replaces linear layers with ``WQLinear`` and marks
                the wrapper as quantized.
        """
        if quant_config is None:
            # A shared mutable ``{}`` default used to wipe the stored config and make
            # _awq_quant fail with KeyError('zero_point').
            quant_config = self.quant_config or dict(self._DEFAULT_QUANT_CONFIG)
        self.quant_config = quant_config

        if run_search:
            self.search_result = self._awq_search(
                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
            )

        if run_quant:
            self._awq_quant()
            self.is_quantized = True

    @staticmethod
    def fuse_layers(model):
        """Hook for subclasses to fuse layers after loading; a no-op here."""
        pass

    def _awq_quant(self):
        """Replace every linear layer of every block with a packed ``WQLinear``.

        Each linear is moved to CUDA and pseudo-quantized with ``self.quant_config``
        to obtain scales and zero points. These are transposed and packed via
        ``WQLinear.from_linear``. The original module is moved back to CPU.
        """
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for layer in tqdm(layers, desc="AWQ Quantization"):
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    **self.quant_config
                )

                # WQLinear.from_linear receives the transposed, contiguous scales/zeros.
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()

                q_linear = WQLinear.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,
                    scales,
                    zeros
                )

                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data: Union[str, List[str]] = "pileval"):
        """Search per-block AWQ scales and clipping values on calibration data.

        Returns:
            dict with keys ``"scale"`` and ``"clip"``. Each holds the concatenated
            results of ``auto_scale_block`` / ``auto_clip_block`` with names
            prefixed by the block's name in ``self.model``.

        Raises:
            RuntimeError: if the calibration forward pass never reaches layer 0.
        """
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        class _FirstLayerReached(Exception):
            """Private signal that aborts the forward pass once layer 0's inputs are captured."""

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                # Dedicated exception so a genuine ValueError from the model is not swallowed.
                raise _FirstLayerReached

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except _FirstLayerReached:  # expected early exit
            pass
        finally:
            layers[0] = layers[0].module  # restore even if the forward pass failed
        del samples

        if not inps:
            raise RuntimeError("AWQ search: the calibration forward pass never reached the first layer.")
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        def cache_input_hook(m, x, y, name, feat_dict):
            # Forward hooks receive the positional inputs as a tuple; keep the first one on CPU.
            feat_dict[name].append(x[0].detach().cpu())

        # Run AWQ search layer by layer
        for layer in tqdm(layers, desc="AWQ Search"):
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            input_feat = defaultdict(list)
            handles = [
                module.register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
                for name, module in named_linears.items()
            ]
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for handle in handles:
                handle.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            layer_prefix = get_op_name(self.model, layer) + "."

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )
                apply_scale(layer, scales_list, input_feat_dict=input_feat)
                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, layer_prefix)

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )
                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, layer_prefix)

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized(self, save_dir):
        """Save configs plus quantized weights (or search results) into ``save_dir``.

        If the model is quantized, or no search result exists, ``self.model.state_dict()``
        is saved as ``awq_model_w{w_bit}_g{q_group_size}.pt``. Otherwise
        ``self.search_result`` is saved as ``awq_model_search_result.pt``. In both
        cases the config files come from ``self.model.save_pretrained`` (called with
        an empty state dict), and ``self.quant_config`` is written to
        ``quant_config.json``.
        """
        if save_dir.endswith('/'):  # endswith is safe on "" (indexing [-1] was not)
            save_dir = save_dir[:-1]

        def _save_files(payload_name, payload):
            class EmptyModule(nn.Module):
                def __init__(self):
                    super(EmptyModule, self).__init__()

                def forward(self, x):
                    return x

            # Save model files (configs) without any weights
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove the empty weight file save_pretrained wrote; its name depends on
            # whether transformers serialized to .bin or to safetensors.
            for empty_weights in ('pytorch_model.bin', 'model.safetensors'):
                empty_path = os.path.join(save_dir, empty_weights)
                if os.path.exists(empty_path):
                    os.remove(empty_path)

            # Save weights / search results
            torch.save(payload, os.path.join(save_dir, payload_name))

            # Save config
            with open(os.path.join(save_dir, 'quant_config.json'), 'w') as file:
                file.write(json.dumps(self.quant_config, indent=4))

        if self.search_result is None or self.is_quantized:
            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
            _save_files(model_name, self.model.state_dict())
        else:
            _save_files('awq_model_search_result.pt', self.search_result)

    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False):
        """Load an unquantized checkpoint; delegates to ``from_quantized(is_quantized=False)``."""
        return cls.from_quantized(
            model_path,
            model_type,
            model_filename='',
            max_new_tokens=None,
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            is_quantized=False,
        )

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename, max_new_tokens=None,
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                       safetensors=False, is_quantized=True, fuse_layers=False):
        """Build a wrapper from a local directory or a Hugging Face Hub repo id.

        Args:
            model_path: local directory; anything else is fetched with ``snapshot_download``.
            model_filename: weights file (or sub-path) relative to ``model_path``.
            max_new_tokens: stored on the config as ``max_new_tokens``. When None,
                it is read from the config attribute named by ``cls.max_new_tokens_key``
                if the class defines one, otherwise it defaults to 2048.
            device: ``device_map`` for ``load_checkpoint_and_dispatch`` (quantized path only).
            is_quantized: replace linears with ``WQLinear`` and load the AWQ checkpoint;
                otherwise load regular weights through ``AutoModelForCausalLM``.
            fuse_layers: call ``cls.fuse_layers`` after loading (quantized path only).
        """
        default_max_new_tokens = 2048

        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        # TODO: Better naming, model_filename becomes a directory
        model_filename = f'{model_path}/{model_filename}'

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.load(file)
        else:
            # Default config that works for most models
            quant_config = dict(cls._DEFAULT_QUANT_CONFIG)

        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(cls, 'max_new_tokens_key'):
            # Fall back instead of raising AttributeError when the config lacks the key.
            config.max_new_tokens = getattr(config, cls.max_new_tokens_key, default_max_new_tokens)
        else:
            config.max_new_tokens = default_max_new_tokens if max_new_tokens is None else max_new_tokens

        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            cls._load_quantized_modules(cls, model, quant_config)

        model.tie_weights()

        # Load model weights
        if is_quantized:
            model = load_checkpoint_and_dispatch(
                model, model_filename, device_map=device, no_split_module_classes=[cls.layer_type])

            if fuse_layers:
                cls.fuse_layers(model)
        else:
            # If not quantized, must load with AutoModelForCausalLM
            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[cls.layer_type],
                dtype=torch_dtype
            )

            del model

            # trust_remote_code must match the config/architecture load above.
            model = AutoModelForCausalLM.from_pretrained(
                model_filename,
                device_map=device_map,
                offload_folder="offload",
                offload_state_dict=True,
                torch_dtype=torch_dtype,
                use_safetensors=safetensors,
                trust_remote_code=trust_remote_code,
            )
            model.eval()

        return cls(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config):
        """Swap every linear layer of ``model``'s blocks for an empty ``WQLinear`` shell.

        NOTE: ``from_quantized`` calls this with the class itself passed as ``self``.
        """
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):
            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    module, quant_config['w_bit'], quant_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        """Wrap the block's scalable activation in ``ScaledActivation`` (scale initialised to ones).

        Declared static with an explicit ``self`` so it can be invoked with either an
        instance or the class (see ``_load_quantized_modules``).
        """
        scale_dict = self.get_act_for_scaling(layer)

        if not scale_dict['is_scalable']:
            return
        if isinstance(scale_dict['scale_layer'], ScaledActivation):
            return  # already wrapped

        param = next(layer.parameters())

        # get activation scale
        scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

        # scale activation
        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)