base.py 12 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
3
4
import gc
import torch
import functools
5
import accelerate
Casper Hansen's avatar
Casper Hansen committed
6
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
7
from tqdm import tqdm
Casper Hansen's avatar
Casper Hansen committed
8
9
from collections import defaultdict

10
from huggingface_hub import snapshot_download
Casper Hansen's avatar
Casper Hansen committed
11
from awq.utils.calib_data import get_calib_dataset
12
13
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
Casper Hansen's avatar
Casper Hansen committed
14
15
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
16
17
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
Casper Hansen's avatar
Casper Hansen committed
18
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
Casper Hansen's avatar
Casper Hansen committed
19

Casper's avatar
Casper committed
20
class BaseAWQForCausalLM:
    """Wrapper around a transformers causal LM that adds AWQ search and quantization.

    This class relies on attributes it does not define itself: ``layer_type``,
    ``get_model_layers``, ``get_act_for_scaling`` and ``move_embed``.
    NOTE(review): these are presumably supplied by model-specific subclasses — confirm.
    """

    def __init__(self, model, model_type, is_quantized):
        # The wrapped transformers model.
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        # Output of _awq_search ({"scale": [...], "clip": [...]}); None until a search runs.
        self.search_result = None

    def to(self, device: str):
        """Move the wrapped model to ``device``; returns the wrapped model, not this wrapper."""
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """Call the wrapped model with all given arguments and return its output."""
        return self.model(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, w_bit=4, q_config=None, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                 calib_data="pileval"):
        """Optionally run the AWQ search, then optionally quantize the weights in place.

        Args:
            tokenizer: tokenizer used to build the calibration set (search only).
            w_bit: number of weight bits.
            q_config: quantization options forwarded to the search/quantizer
                (``zero_point`` and ``q_group_size`` are read by ``_awq_quant``).
                ``None`` means an empty config.
            n_samples, seqlen, calib_data: calibration dataset settings (search only).
            auto_scale, mse_range: enable scale search / clip search (search only).
            run_search: if True, store the search result in ``self.search_result``.
            run_quant: if True, replace linear layers with ``WQLinear``.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        q_config = {} if q_config is None else q_config

        if run_search:
            self.search_result = self._awq_search(
                tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant(w_bit, q_config)

    def _awq_quant(self, w_bit, q_config):
        """Replace every linear layer of every block with a ``WQLinear``, in place.

        Raises:
            AssertionError: if ``q_config`` does not enable ``zero_point``.
            KeyError: if ``q_config`` has no ``q_group_size``.
        """
        # .get() so a missing key yields the message below instead of a bare KeyError.
        assert q_config.get("zero_point"), "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for layer in tqdm(layers, desc="AWQ Quantization"):
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                # Quantize on GPU, then move the float module back to CPU to free memory.
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data="pileval"):
        """Search AWQ scales and clipping values layer by layer on calibration data.

        The found scales/clips are also applied to the layers in place.

        Returns:
            dict with keys ``"scale"`` and ``"clip"``; each list holds the
            per-layer results with op names prefixed by the layer's name in
            ``self.model``.

        Raises:
            RuntimeError: if the inputs of the first layer could not be captured.
        """
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
        samples = torch.cat(samples, dim=0)

        captured_inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class _CatcherExit(Exception):
            """Raised by Catcher to stop the forward pass once layer 0's inputs are recorded."""

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                captured_inps.append(inp)
                layer_kwargs.update(kwargs)
                # Dedicated exception type so that a genuine ValueError raised
                # earlier in the model is not silently swallowed.
                raise _CatcherExit

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except _CatcherExit:  # expected early exit
            pass
        finally:
            # Always restore the real layer 0, even if the forward pass failed.
            layers[0] = layers[0].module
        del samples

        if not captured_inps:
            raise RuntimeError("Failed to capture the inputs of the first model layer.")
        inps = captured_inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Forward hook: record (on CPU) the first positional input of a linear layer.
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            input_feat = defaultdict(list)
            handles = [
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat))
                for name in named_linears
            ]
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer, layer_kwargs,
                    w_bit=w_bit, q_config=q_config,
                    input_feat=input_feat,
                )
                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(
                    scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(layer,
                                            w_bit=w_bit, q_config=q_config,
                                            input_feat=input_feat)
                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(
                    clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized(self, save_dir):
        """Save the model config plus either the search result or the full state dict.

        The model is saved via ``save_pretrained`` with an empty state dict (so only
        config-type files remain), the resulting empty weights file is removed, and
        the payload is written with ``torch.save`` to ``awq_model_w4_g128.pt``
        (no search result) or ``awq_model_search_result.pt`` (search result present).
        """
        def _save_files(save_dir, model_name, payload):
            # Save model files without weights (empty state dict).
            self.model.save_pretrained(save_dir, state_dict=nn.Module().state_dict())

            # Remove the empty weights file; depending on the transformers version
            # and settings it is written as .bin or .safetensors.
            for weights_file in ("pytorch_model.bin", "model.safetensors"):
                weights_path = os.path.join(save_dir, weights_file)
                if os.path.exists(weights_path):
                    os.remove(weights_path)

            # Save the actual payload (search results or quantized state dict).
            torch.save(payload, os.path.join(save_dir, model_name))

        # endswith() also copes with an empty string, unlike save_dir[-1].
        save_dir = save_dir[:-1] if save_dir.endswith('/') else save_dir

        # Save model
        if self.search_result is None:
            _save_files(save_dir, 'awq_model_w4_g128.pt', self.model.state_dict())
        else:
            _save_files(save_dir, 'awq_model_search_result.pt', self.search_result)

    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True):
        """Load an unquantized model; thin wrapper over ``from_quantized(is_quantized=False)``."""
        return cls.from_quantized(
            model_path,
            model_type,
            model_filename='',
            device='balanced',
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            is_quantized=False
        )

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename, w_bit=4, q_config=None,
                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                       safetensors=False, is_quantized=True):
        """Build the model from its config, swap in WQLinear layers if quantized, load weights.

        Args:
            model_path: local directory, or a hub id that is downloaded with
                ``snapshot_download`` when it is not a directory.
            model_filename: checkpoint name inside ``model_path`` ('' = the directory).
            q_config: quantization options (``None`` means an empty config).
            device: ``device_map`` passed to ``load_checkpoint_and_dispatch``.
            safetensors: download only safetensors files instead of .bin/.pt.
            is_quantized: whether linear layers must be replaced by WQLinear first.

        Returns:
            An instance of ``cls`` wrapping the loaded model (in eval mode).
        """
        # Fresh dict per call instead of a shared mutable default argument.
        q_config = {} if q_config is None else q_config

        # Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt", "*.bin"])
            else:
                ignore_patterns.append("*safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        # TODO: Better naming, model_filename becomes a directory
        model_filename = model_path + f'/{model_filename}'

        # Load config
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # Load empty weights
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Only need to replace layers if a model is AWQ quantized
        if is_quantized:
            # Prepare WQLinear layers, replace nn.Linear
            cls._load_quantized_modules(cls, model, w_bit, q_config)

        model.tie_weights()

        # Load model weights
        try:
            model = load_checkpoint_and_dispatch(
                model, model_filename, device_map=device,
                no_split_module_classes=[cls.layer_type])
        except Exception as ex:
            # Fallback to auto model if load_checkpoint_and_dispatch is not working
            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')

            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[cls.layer_type],
                dtype=torch_dtype
            )

            del model

            # Load model weights; trust_remote_code must be forwarded here too,
            # otherwise custom-code models cannot be loaded by the fallback.
            model = AutoModelForCausalLM.from_pretrained(
                model_filename, device_map=device_map, offload_folder="offload",
                offload_state_dict=True, torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code
            )

        # Inference mode on both loading paths (previously only the fallback did this).
        model.eval()

        return cls(model, model_type, is_quantized=is_quantized)

    def _load_quantized_modules(self, model, w_bit, q_config):
        """Replace activations and nn.Linear layers of ``model`` with AWQ modules, in place.

        ``self`` may be the class itself (``from_quantized`` calls this on the class).

        Raises:
            AssertionError: if ``q_config`` does not enable ``zero_point``.
        """
        # Real quantization of weights
        assert q_config.get("zero_point"), "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):
            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    # Kept as a staticmethod taking `self` explicitly because callers invoke it as
    # `self._scale_activations(self, layer)` with either an instance or the class.
    @staticmethod
    def _scale_activations(self, layer):
        """Wrap the layer's scalable activation in a ``ScaledActivation`` with unit scales."""
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            # Skip if the activation has already been wrapped.
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale (initialised to ones, matching the layer's dtype/device)
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)