import os
import gc
import json
import torch
import functools
import accelerate
import torch.nn as nn
from tqdm import tqdm

from collections import defaultdict
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset

from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation

from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale

from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name


class BaseAWQForCausalLM(nn.Module):
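    """Base class wrapping a Hugging Face causal LM with AWQ search, quantization,
    and save/load support. Model-specific subclasses are expected to provide
    `get_model_layers`, `move_embed`, `get_act_for_scaling` and `layer_type`.

    Rough usage sketch (`SomeAWQForCausalLM` stands in for a concrete subclass):

        model = SomeAWQForCausalLM.from_pretrained(model_path, 'model_type')
        model.quantize(tokenizer, quant_config, run_search=True, run_quant=True)
        model.save_quantized('quantized-model')
    """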
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: dict = quant_config

    def to(self, device: str):
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                       calib_data="pileval"):
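        """Optionally search for AWQ scaling/clipping parameters on calibration data,
        then optionally quantize the model's weights in place."""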
        self.quant_config = quant_config

        if run_search:
            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                       auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant(quant_config)

    def _awq_quant(self, quant_config):
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

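            # Pseudo-quantize each linear layer's weight, then replace it with a
            # packed WQLinear holding the quantized weights, scales and zeros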
            for name, module in named_linears.items():
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    w_bit=quant_config['w_bit'],
                    zero_point=quant_config['zero_point'],
                    q_group_size=quant_config['q_group_size'],
                    get_scale_zp=True
                )
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
            
            torch.cuda.empty_cache()
            gc.collect()
    
    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, calib_data="pileval"):
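        """Record layer inputs on calibration data and search for AWQ scaling
        (and optionally clipping) parameters, one decoder layer at a time."""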
        layers = self.get_model_layers(self.model)

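        # Tokenize the calibration set and concatenate it into a single batch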
        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")
        
        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps.append(inp)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        layers[0] = layers[0].module  # restore
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")
        
        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            def cache_input_hook(m, x, y, name, feat_dict):
                x = x[0]
                x = x.detach().cpu()
                feat_dict[name].append(x)

            input_feat = defaultdict(list)
            handles = []
            for name in named_linears:
                handles.append(named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name,
                                    feat_dict=input_feat)))
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()
            
            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()
        
        return awq_results

    def save_quantized(self, save_dir):
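        """Write the model config, quant_config.json and a single .pt payload
        (quantized state dict, or search results if only search was run) to save_dir."""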
        def _save_files(save_dir, model_name, model):
            class EmptyModule(nn.Module):
                def __init__(self): super(EmptyModule, self).__init__()
                def forward(self, x): return x

            # Save model files (config etc.) without the weights
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove empty module
            os.remove(f'{save_dir}/pytorch_model.bin')

            # Save the payload (quantized state dict or search results)
            torch.save(model, f'{save_dir}/{model_name}')

            # Save config
            with open(f'{save_dir}/quant_config.json', 'w+') as file:
                file.write(json.dumps(self.quant_config, indent=4))

        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Save model
        if self.search_result is None:
            model_name = 'awq_model_w4_g128.pt'
            _save_files(save_dir, model_name, self.model.state_dict())
        else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)
        
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True):
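        # An unquantized model goes through the same loading path as a quantized
        # one, just without the WQLinear replacement (is_quantized=False)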
        return self.from_quantized(
            model_path, 
            model_type, 
            model_filename='', 
            device='balanced', 
            torch_dtype=torch_dtype, 
            trust_remote_code=trust_remote_code, 
            is_quantized=False
        )

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, quant_config={}, 
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, 
                       safetensors=False, is_quantized=True):
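        """Build the model skeleton on empty weights, swap nn.Linear for WQLinear
        when the checkpoint is AWQ-quantized, then load and dispatch the weights."""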
        # Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
        
        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'

        # Load config
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # Load empty weights
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
        
        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, w_bit, quant_config)
        
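        # Tie weights (e.g. input embeddings and LM head) before loading the checkpoint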
        model.tie_weights()

        # Load model weights
        try:
            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
        except Exception as ex:
            # Fallback to auto model if load_checkpoint_and_dispatch is not working
            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')

            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[self.layer_type], 
                dtype=torch_dtype
            )
            
            del model
            
            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype
            )
            model.eval()

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, w_bit, quant_config):
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        
        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    module, w_bit, quant_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

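        # Wrap the activation in a ScaledActivation with identity scales so that
        # AWQ's per-channel scales can later be folded into it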
        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)