import os
import gc
import json

import torch
import functools
import torch.nn as nn
from tqdm import tqdm

from collections import defaultdict

from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset

from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation

from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale

from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

class BaseAWQForCausalLM:
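    """
    Base class that wraps a Hugging Face causal LM and drives AWQ search,
    quantization, saving and loading. Subclasses provide the model-specific
    hooks used below: `get_model_layers`, `move_embed`, `get_act_for_scaling`
    and the `layer_type` attribute.

    Rough usage sketch (the model class name and config values are illustrative):

        model = SomeAWQForCausalLM.from_pretrained("org/model", "model_type")
        model.quantize(tokenizer,
                       quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4},
                       run_search=True, run_quant=True)
        model.save_quantized("./quantized")
    """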
    def __init__(self, model, model_type, is_quantized, quant_config):
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: dict = quant_config
    
    def to(self, device: str):
        return self.model.to(device)
    
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

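    # Entry point: optionally run the AWQ search (solving scales and clipping
    # thresholds on calibration data), then apply real quantization, which
    # swaps each nn.Linear for a packed WQLinear.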
    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                       calib_data="pileval"):
        # avoid a shared mutable default argument for the config dict
        quant_config = quant_config if quant_config is not None else {}
        self.quant_config = quant_config

        if run_search:
            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                       auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
        
        if run_quant:
            self._awq_quant(quant_config)
    
    
    def _awq_quant(self, quant_config):
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

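            # quantize each linear layer's weights and swap in a packed WQLinear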
            for name, module in named_linears.items():
                module.cuda()
                # exclude w_bit from the unpacked config so it is not passed twice
                quant_kwargs = {k: v for k, v in quant_config.items() if k != 'w_bit'}
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data, w_bit=quant_config['w_bit'], get_scale_zp=True, **quant_kwargs)
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                       auto_scale=True, mse_range=True, calib_data="pileval"):
        layers = self.get_model_layers(self.model)

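        # Tokenize calibration data and stack the samples into one batch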
        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")
        
        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps.append(inp)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        layers[0] = layers[0].module  # restore
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")
        
        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            def cache_input_hook(m, x, y, name, feat_dict):
                x = x[0]
                x = x.detach().cpu()
                feat_dict[name].append(x)

            input_feat = defaultdict(list)
            handles = []
            for name in named_linears:
                handles.append(named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name,
                                    feat_dict=input_feat)))
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )

                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)

                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()
            
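            # mse_range: search per-channel weight-clipping thresholds that
            # minimize quantization error on the cached inputs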
            if mse_range:
                clip_list = auto_clip_block(
                    layer,
                    quant_config=quant_config,
                    input_feat=input_feat
                )

                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized(self, save_dir):
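        # Helper: save_pretrained with an empty state dict writes the config
        # files (plus an empty weights file, deleted below); the real tensors
        # are then stored separately with torch.save.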
        def _save_files(save_dir, model_name, model):
            class EmptyModule(nn.Module):
                def __init__(self): super(EmptyModule, self).__init__()
                def forward(self, x): return x

            # Save model files without the weights / search results
            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

            # Remove empty module
            os.remove(f'{save_dir}/pytorch_model.bin')

            # Save search results
            torch.save(model, f'{save_dir}/{model_name}')

            # Save config
            with open(f'{save_dir}/quant_config.json', 'w+') as file:
                file.write(json.dumps(self.quant_config, indent=4))

        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Save model
        if self.search_result is None:
            model_name = 'awq_model_w4_g128.pt'
            _save_files(save_dir, model_name, self.model.state_dict())
        else:
            model_name = 'awq_model_search_result.pt'
            _save_files(save_dir, model_name, self.search_result)
        
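    # Loads an unquantized FP16 model; delegates to from_quantized with
    # is_quantized=False so no layers are replaced.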
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, 
                        trust_remote_code=True):
        return self.from_quantized(
            model_path, 
            model_type, 
            model_filename='', 
            device='balanced', 
            torch_dtype=torch_dtype, 
            trust_remote_code=trust_remote_code, 
            is_quantized=False
        )

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, quant_config=None, 
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True, 
                       safetensors=False, is_quantized=True):
        # avoid a shared mutable default argument for the config dict
        quant_config = quant_config if quant_config is not None else {}

        # Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'

        # Load config
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # Load empty weights
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            self._load_quantized_modules(self, model, w_bit, quant_config)
        
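        # Re-tie shared weights (e.g. input/output embeddings) before loading the checkpoint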
        model.tie_weights()

        # Load model weights
        try:
            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
        except Exception as ex:
            # Fallback to auto model if load_checkpoint_and_dispatch is not working
            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')

            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[self.layer_type], 
                dtype=torch_dtype
            )
            
            del model
            
            # Load model weights
            model = AutoModelForCausalLM.from_pretrained(
                model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype
            )
            model.eval()

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_quantized_modules(self, model, w_bit, quant_config):
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    module, w_bit, quant_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            
            torch.cuda.empty_cache()
            gc.collect()
    
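    # NOTE: declared @staticmethod, so callers pass the instance explicitly as
    # the first argument (see the `self._scale_activations(self, layer)` call sites).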
    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)