base.py 13.8 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import gc
3
import json
Casper Hansen's avatar
Casper Hansen committed
4
5
6
import torch
import functools
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
7
from tqdm import tqdm
Casper Hansen's avatar
Casper Hansen committed
8
9
from collections import defaultdict

10
from huggingface_hub import snapshot_download
Casper Hansen's avatar
Casper Hansen committed
11
from awq.utils.calib_data import get_calib_dataset
12
13
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
14
15
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
16
17
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
Casper Hansen's avatar
Casper Hansen committed
18
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
Casper Hansen's avatar
Casper Hansen committed
19

20
class BaseAWQForCausalLM(nn.Module):
    """Wrapper adding AWQ search, quantization, saving and loading to a causal LM.

    Several hooks used below are not defined in this class and are presumably
    supplied by model-specific subclasses: ``get_model_layers``,
    ``get_act_for_scaling``, ``move_embed``, ``fuse_layers``, ``layer_type``
    and (optionally) ``max_new_tokens_key``.
    """

    def __init__(self, model, model_type, is_quantized, quant_config):
        """Wrap an already-constructed model.

        Args:
            model: the underlying transformers ``PreTrainedModel``.
            model_type: model family identifier, stored as-is.
            is_quantized: whether ``model`` already holds AWQ ``WQLinear`` layers.
            quant_config: dict read by the methods below via the keys
                ``zero_point``, ``q_group_size`` and ``w_bit``.
        """
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        # Output of the last `_awq_search` run: {"scale": [...], "clip": [...]}.
        self.search_result = None
        self.quant_config: dict = quant_config

    def to(self, device: str):
        """Move the wrapped model to ``device``.

        Note: returns whatever ``self.model.to`` returns (the wrapped model),
        not this wrapper as ``nn.Module.to`` would.
        """
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Call ``self.model.generate`` under ``torch.inference_mode``."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                 calib_data="pileval"):
        """Run the AWQ search and/or the real quantization of the wrapped model.

        Args:
            tokenizer: tokenizer used to build the calibration set (search only).
            quant_config: quantization settings. When ``None`` (the default)
                the config the wrapper already holds is kept; otherwise it
                replaces ``self.quant_config``.
            n_samples, seqlen, calib_data: calibration-set parameters passed
                on to ``get_calib_dataset`` by ``_awq_search``.
            auto_scale: search for and apply activation-aware scales.
            mse_range: search for and apply weight clipping ranges.
            run_search: run the search and store it in ``self.search_result``.
            run_quant: replace linear layers with ``WQLinear`` modules and mark
                the wrapper as quantized.
        """
        # Previously a shared ``{}`` default silently overwrote the existing
        # config, making `_awq_quant` fail on the missing "zero_point" key.
        if quant_config is not None:
            self.quant_config = quant_config

        if run_search:
            self.search_result = self._awq_search(
                tokenizer, self.quant_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant()
            self.is_quantized = True

    def _awq_quant(self):
        """Replace every linear layer of every block with a packed ``WQLinear``.

        Each linear is moved to GPU, quantized with ``pseudo_quantize_tensor``
        using ``self.quant_config``, converted with ``WQLinear.from_linear`` and
        swapped into its block. Activations reported by ``get_act_for_scaling``
        are first wrapped in ``ScaledActivation``.

        Raises:
            AssertionError: if ``quant_config["zero_point"]`` is falsy.
        """
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for layer in tqdm(layers, desc="AWQ Quantization"):
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                # Helper returns the (pseudo-)quantized weight together with the
                # scales and zero points needed to pack it.
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    **self.quant_config
                )

                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()

                q_linear = WQLinear.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,  # presumably init_only=False: pack real weights — TODO confirm
                    scales,
                    zeros
                )

                # Free the GPU copy of the original weight before the next linear.
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data="pileval"):
        """Search AWQ scales and clipping ranges block by block.

        The calibration set is run through the model once to capture the
        positional input and keyword arguments of the first block. Those are
        then fed through the blocks one at a time on GPU. Each block's output
        becomes the next block's input. Results found for a block are applied
        to it with ``apply_scale`` / ``apply_clip`` before moving on.

        Returns:
            dict with "scale" and "clip" lists. Their entries are prefixed
            with the block's module name (via ``append_str_prefix``) so that
            names are global to ``self.model``.
        """
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        class _StopForward(ValueError):
            """Raised by Catcher to abort the forward pass once block 0 input is captured."""

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise _StopForward  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except _StopForward:
            # Expected early exit. Genuine ValueErrors from the model are no
            # longer swallowed here.
            pass
        finally:
            layers[0] = layers[0].module  # restore, even if the forward failed
        del samples
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        def cache_input_hook(m, x, y, name, feat_dict):
            # Keep a CPU copy of the first positional input of each linear.
            feat_dict[name].append(x[0].detach().cpu())

        # Run AWQ search layer by layer
        for layer in tqdm(layers, desc="AWQ Search"):
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            input_feat = defaultdict(list)
            handles = [
                module.register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
                for name, module in named_linears.items()
            ]
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            try:
                # get output as next layer's input
                inps = layer(inps, **layer_kwargs)[0]
            finally:
                for h in handles:
                    h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layer, scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized(self, save_dir):
        """Save the quantized weights, or the raw search results, to ``save_dir``.

        The transformers config files are written via ``save_pretrained`` with
        an empty state dict, and the placeholder weight file it produces is
        deleted. Then a ``.pt`` payload and ``quant_config.json`` are written.
        If the model is quantized, or no search result exists, the payload is
        ``self.model.state_dict()`` saved as ``awq_model_w{w_bit}_g{q_group_size}.pt``.
        Otherwise it is ``self.search_result`` saved as
        ``awq_model_search_result.pt``.
        """
        def _save_files(save_dir, model_name, payload):
            class EmptyModule(nn.Module):
                def __init__(self): super(EmptyModule, self).__init__()
                def forward(self, x): return x

            # Save model files (config etc.) without any weights
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove the placeholder weight file written for the empty state
            # dict. Its name depends on the transformers version / default
            # serialization (pickle vs safetensors).
            for placeholder in ('pytorch_model.bin', 'model.safetensors'):
                placeholder_path = f'{save_dir}/{placeholder}'
                if os.path.exists(placeholder_path):
                    os.remove(placeholder_path)

            # Save quantized state dict or search results
            torch.save(payload, f'{save_dir}/{model_name}')

            # Save config
            with open(f'{save_dir}/quant_config.json', 'w+') as file:
                file.write(json.dumps(self.quant_config, indent=4))

        save_dir = save_dir[:-1] if save_dir.endswith('/') else save_dir

        # Save model
        if self.search_result is None or self.is_quantized:
            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
            _save_files(save_dir, model_name, self.model.state_dict())
        else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)

    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False):
        """Load an unquantized model, e.g. to run ``quantize`` on it.

        Thin wrapper around ``from_quantized`` with ``is_quantized=False`` and
        ``model_filename=''``, so the weights are loaded from ``model_path``
        itself.
        """
        return cls.from_quantized(
            model_path,
            model_type,
            model_filename='',
            max_new_tokens=None,
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            is_quantized=False
        )

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename, max_new_tokens=None,
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                       safetensors=False, is_quantized=True, fuse_layers=False):
        """Build the wrapper from a local directory or a Hugging Face Hub repo id.

        Args:
            model_path: local directory. Anything else is passed to
                ``snapshot_download`` and replaced by the downloaded path.
            model_type: stored on the returned wrapper.
            model_filename: path relative to ``model_path`` of the quantized
                checkpoint. When ``is_quantized`` is False, the joined path is
                passed to ``AutoModelForCausalLM.from_pretrained``.
            max_new_tokens: stored on the config as ``max_new_tokens``. When
                None, it is read from ``config.<max_new_tokens_key>`` if the
                class defines ``max_new_tokens_key``, else 2048.
            device: ``device_map`` for ``load_checkpoint_and_dispatch``
                (quantized path only).
            torch_dtype, trust_remote_code: forwarded to the transformers loaders.
            safetensors: selects which weight files are downloaded and, for
                unquantized models, whether safetensors are loaded.
            is_quantized: load AWQ ``WQLinear`` modules, or a plain model.
            fuse_layers: call ``cls.fuse_layers`` on a loaded quantized model.

        Returns:
            An instance of ``cls``. ``quant_config`` comes from
            ``quant_config.json`` when present, else from a default config.
        """
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.load(file)
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(cls, 'max_new_tokens_key'):
            config.max_new_tokens = getattr(config, cls.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens

        # [STEP 3] Load model
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            cls._load_quantized_modules(cls, model, quant_config)

        model.tie_weights()

        # Load model weights
        if is_quantized:
            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[cls.layer_type])

            if fuse_layers:
                cls.fuse_layers(model)

        else:
            # If not quantized, must load with AutoModelForCausalLM
            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[cls.layer_type],
                dtype=torch_dtype
            )

            del model

            # Load model weights. trust_remote_code must match the config /
            # empty-model construction above, or remote-code models fail here.
            model = AutoModelForCausalLM.from_pretrained(
                model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True,
                torch_dtype=torch_dtype, use_safetensors=safetensors, trust_remote_code=trust_remote_code
            )
            model.eval()

        return cls(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, quant_config):
        """Swap every linear layer of ``model``'s blocks for a ``WQLinear`` shell.

        Used by ``from_quantized`` (which passes the class itself as ``self``)
        before the quantized checkpoint is loaded, so that the module tree
        matches the saved state dict.

        Raises:
            AssertionError: if ``quant_config["zero_point"]`` is falsy.
        """
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):
            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    # trailing True: presumably init_only (allocate, don't pack) — TODO confirm
                    module, quant_config['w_bit'], quant_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        """Wrap the activation reported by ``get_act_for_scaling`` in ``ScaledActivation``.

        Declared static, but the (sub)class or instance is passed explicitly
        as ``self`` by callers, e.g. ``self._scale_activations(self, layer)``.
        The scale starts as ones, with the dtype/device of the block's first
        parameter. Nothing changes when the activation is reported as not
        scalable or is already a ``ScaledActivation``.
        """
        scale_dict = self.get_act_for_scaling(layer)

        if not scale_dict['is_scalable'] or isinstance(scale_dict['scale_layer'], ScaledActivation):
            return

        param = next(layer.parameters())

        # get activation scale
        scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

        # scale activation
        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)