quantizer.py 16.4 KB
Newer Older
Ji Lin's avatar
Ji Lin committed
1
import torch
2
import inspect
Casper's avatar
Casper committed
3
4
5
6
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
Vik Paruchuri's avatar
Vik Paruchuri committed
7
from typing import Dict, List
Casper's avatar
Casper committed
8
9
10
from collections import defaultdict
from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
Casper Hansen's avatar
Casper Hansen committed
11
from awq.quantize.scale import apply_scale, apply_clip
Casper's avatar
Casper committed
12
13
14
15
16
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name


class AwqQuantizer:
17
    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, 
18
                       calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
Casper Hansen's avatar
Casper Hansen committed
19
        self.awq_model = awq_model
Casper's avatar
Casper committed
20
21
22
23
24
25
26
27
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
28
        self.duo_scaling = duo_scaling
29
        self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
Casper Hansen's avatar
Casper Hansen committed
30
        self.modules, self.module_kwargs, self.inps = self.init_quant()
31
    
Casper's avatar
Casper committed
32
33
34
35
36
37
38
39
    def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2

        # zero point quantization
Ji Lin's avatar
Ji Lin committed
40
41
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
Casper's avatar
Casper committed
42
        max_int = 2 ** self.w_bit - 1
Ji Lin's avatar
Ji Lin committed
43
44
45
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
Casper's avatar
Casper committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(w).sum() == 0

        w = w.reshape(org_w_shape)

        if get_scale_zp:
            return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
        else:
            return w
    
Casper's avatar
Casper committed
60
61
62
63
64
65
66
67
68
69
70
71
72
    def pseudo_dequantize_tensor(self, w: nn.Linear, scales: torch.Tensor, zeros: torch.Tensor):
        # get repeated count
        repeat_count = w.weight.data.shape[-1] // zeros.shape[-1]

        # get zeros and scales in correct shape
        zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
        scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

        # dequantize
        w = (w.weight.data - zeros) * scales

        return w
    
73
74
75
76
77
78
79
    def _exclude_layers_to_not_quantize(self, linear_layers):
        filtered_layers = {}
        for name, linear_layer in linear_layers.items():
            if not any(key in name for key in self.modules_to_not_convert):
                filtered_layers[name] = linear_layer
        return filtered_layers
    
Casper Hansen's avatar
Casper Hansen committed
80
81
    def quantize(self):
        """
        Quantize every layer in `self.modules`, one after the other.

        Each layer goes through four steps: capture the inputs seen by its
        linear sub-modules, search for and apply per-channel scales, search
        for and apply weight clipping thresholds, and finally swap the linear
        layers for their quantized replacements. Capturing the inputs also
        advances `self.inps` to the layer's output, which feeds the next
        layer.
        """
        for idx in tqdm(range(len(self.modules)), desc="AWQ"):
            # A layer that still lives on the CPU is moved to the GPU first;
            # the calibration activations then follow the layer's device.
            layer_device = next(self.modules[idx].parameters()).device
            if layer_device is None or str(layer_device) == "cpu":
                self.modules[idx] = self.modules[idx].cuda()
                layer_device = next(self.modules[idx].parameters()).device
            self.inps = self.inps.to(layer_device)

            layer = self.modules[idx]

            # [STEP 1]: Collect the layer's linear modules, drop the ones we
            # do not want to quantize, and record their input features
            linears = self._exclude_layers_to_not_quantize(get_named_linears(layer))
            input_feat = self._get_input_feat(layer, linears)
            clear_memory()

            # [STEP 2]: Search one set of scales per scaling group, then apply them
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                layer, input_feat, self.module_kwargs
            )
            scales_list = []
            for scaling_group in module_config:
                scales_list.append(self._search_best_scale(layer, **scaling_group))
            apply_scale(layer, scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # [STEP 3]: Search the clipping thresholds, then apply them
            clip_list = self._search_best_clip(layer, linears, input_feat)
            apply_clip(layer, clip_list)
            clip_list = append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            # [STEP 4]: Replace the linear layers with quantized modules
            self._apply_quant(layer, linears)
            clear_memory()
    
Vik Paruchuri's avatar
Vik Paruchuri committed
116
    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.cuda().half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data, 
                get_scale_zp=True
            )

            if self.version == 'GEMM':
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM

            elif self.version  == 'GEMV':
                q_linear_module = WQLinear_GEMV
Casper's avatar
Casper committed
133
            
134
135
136
137
138
139
140
141
142
143
144
145
            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
Casper's avatar
Casper committed
146
147
148
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs=None):
        """
        Search the per-input-channel scales for a group of linear layers.

        Args:
            module (`torch.nn.Module`):
                Parent module; used to resolve the names of `prev_op` and `layers`.
            prev_op (`torch.nn.Module`):
                Operation that precedes `layers`; only its name is returned here.
            layers (`List[nn.Linear]`):
                Linear layers that share the same input and are scaled together.
            inp (`torch.Tensor`):
                Calibration input of `layers`.
            module2inspect (`torch.nn.Module`, *optional*):
                Module whose output is compared before / after quantization.
                Defaults to `layers[0]`, which requires exactly one layer.
            kwargs (`dict`, *optional*):
                Extra keyword arguments for the forward pass of
                `module2inspect`. The dict is not modified; `"use_cache"` is
                dropped from a private copy.

        Returns:
            A tuple `(prev_op_name, layer_names, best_scales)`.
        """
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        # Work on a copy so that neither the caller's dict (which may be shared
        # between calls) nor a default argument is mutated.
        kwargs = {k: v for k, v in (kwargs or {}).items() if k != "use_cache"}

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute maximum of weight
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # group_size <= 0 means one group per row, as in _compute_best_clip
        group_size = self.group_size if self.group_size > 0 else org_shape[-1]
        weight = weight.view(-1, group_size)
        abs_weight = weight.abs()
        # normalise every weight by the largest magnitude of its group
        w_scale = abs_weight / abs_weight.amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute maximum of x
        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

        fp16_output = module2inspect(inp, **module_kwargs)
        if isinstance(fp16_output, tuple):
            fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_max, x_max, module2inspect,
            layers, fp16_output, module_kwargs
        )

        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
Casper's avatar
Casper committed
187

Vik Paruchuri's avatar
Vik Paruchuri committed
188
    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
189
                                  fp16_output, kwargs={}):
Casper's avatar
Casper committed
190
191
192
        """
        Compute loss and select best scales

Casper's avatar
Casper committed
193
        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
Casper's avatar
Casper committed
194
195
196
197
198
199
200
201
202
203
204
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float('inf')

Casper Hansen's avatar
Casper Hansen committed
205
        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
Casper's avatar
Casper committed
206
207
208
209
210
        
        device = x.device
        x_max = x_max.view(-1).to(device)
        w_max = w_max.view(-1).to(device)
        
Casper's avatar
Casper committed
211
212
        for ratio in range(n_grid):
            # create new scales
Casper's avatar
Casper committed
213
            ratio = ratio / n_grid
214

Casper Hansen's avatar
Casper Hansen committed
215
            # NOTE: s^-1 * x is fused here, according to paper
216
217
218
219
            if self.duo_scaling:
                scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
            else:
                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
Casper's avatar
Casper committed
220
            scales = scales / (scales.max() * scales.min()).sqrt()
Casper's avatar
Casper committed
221
            scales_view = scales.view(1, -1).to(device)
222

Casper Hansen's avatar
Casper Hansen committed
223
            # Q(W * s)
Casper's avatar
Casper committed
224
            for fc in linears2scale:
Casper's avatar
Casper committed
225
226
                fc.weight.mul_(scales_view)
                fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
Casper's avatar
Casper committed
227

228
229
230
231
232
            # W * X
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]
            
Casper Hansen's avatar
Casper Hansen committed
233
234
            # compute mean squared error (L2 norm)
            loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
Casper's avatar
Casper committed
235
236

            history.append(loss)
Casper's avatar
Casper committed
237
            if loss < best_error:
Casper's avatar
Casper committed
238
239
                best_error = loss
                best_ratio = ratio
Casper's avatar
Casper committed
240
                best_scales = scales.clone()
Casper Hansen's avatar
Casper Hansen committed
241
            module2inspect.load_state_dict(org_sd)
Casper's avatar
Casper committed
242

Casper's avatar
Casper committed
243
244
245
246
247
248
        if best_ratio == -1:
            logging.debug(history)
            raise Exception

        assert torch.isnan(best_scales).sum() == 0, best_scales

Casper Hansen's avatar
Casper Hansen committed
249
        return best_scales.detach().cpu()
Casper's avatar
Casper committed
250

Casper Hansen's avatar
Casper Hansen committed
251
252
253
254
    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
Casper's avatar
Casper committed
255

Casper Hansen's avatar
Casper Hansen committed
256
257
258
259
260
261
262
263
264
265
        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].cuda()
            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
            clip_list.append((name, max_val))

            named_linears[name].cpu()
Casper Hansen's avatar
Casper Hansen committed
266
267
        
        return clip_list
Casper Hansen's avatar
Casper Hansen committed
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284

    @torch.no_grad()
    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
        """
        Grid-search the per-group clipping threshold of a weight matrix.

        For every weight group, candidate thresholds shrink from the group's
        maximum magnitude in steps of `1 / n_grid`. The candidate whose
        clipped and quantized weights reproduce the group's output on the
        sampled inputs with the smallest squared error is kept.

        Args:
            w (`torch.Tensor`):
                2-D weight `[co, ci]`.
            input_feat (`torch.Tensor`):
                Inputs of the layer; the last dimension is `ci`.
            n_grid (`int`):
                Resolution of the shrink grid.
            max_shrink (`float`):
                Largest fraction by which the threshold may shrink;
                `int(max_shrink * n_grid)` candidates are tried.
            n_sample_token (`int`):
                Approximate number of tokens kept from `input_feat`.

        Returns:
            Tensor `[co, n_group, 1]` with the chosen threshold per group.
        """
        assert w.dim() == 2
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else w.shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # Sub-sample the tokens. The stride must be at least 1: with fewer than
        # n_sample_token tokens the integer division yields 0, which is not a
        # valid slice step.
        token_step = max(1, input_feat.shape[1] // n_sample_token)
        input_feat = input_feat[:, ::token_step]
        w = w.reshape(w.shape[0], 1, -1, group_size)

        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
        input_feat = input_feat.to(w.device)
        best_max_val_all = []

        # Output channels are independent, so they are processed in chunks only
        # to bound memory; a shorter trailing chunk is handled by the slice.
        for start in range(0, w.shape[0], oc_batch_size):
            w_batch = w[start:start + oc_batch_size]

            org_max_val = w_batch.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            org_out = (input_feat * w_batch).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = - max_val
                cur_w = torch.clamp(w_batch, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                # keep the threshold wherever this candidate lowers the error
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=128, seqlen=512):
        """
        Capture the input and keyword arguments that reach the first layer.

        The calibration samples are run through the model with the first
        layer temporarily wrapped; the wrapper records what it receives and
        aborts the forward pass.

        Args:
            n_samples (`int`):
                Forwarded to `get_calib_dataset` as `n_samples`.
            seqlen (`int`):
                Forwarded to `get_calib_dataset` as `block_size`.

        Returns:
            A tuple `(modules, layer_kwargs, inps)`: the model layers, the
            keyword arguments for a layer forward pass, and the input captured
            at the first layer.

        Raises:
            RuntimeError: If the forward pass finished without ever reaching
                the first layer.
        """
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, block_size=seqlen,
            split=self.split, text_column=self.text_column
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        modules[0] = modules[0].cuda()
        self.awq_model.move_embed(self.model, "cuda")

        class _EarlyExit(ValueError):
            """Signals that the first layer's inputs have been captured."""

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise _EarlyExit  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except _EarlyExit:
            # Only the dedicated signal is swallowed; a genuine ValueError
            # raised by the model propagates instead of being hidden.
            pass
        finally:
            modules[0] = modules[0].module  # restore, also when the forward failed

        if not inps:
            raise RuntimeError(
                "The first layer was never called; no calibration input was captured."
            )

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")

        del samples
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda")

        return modules, layer_kwargs, inps
Casper's avatar
Casper committed
380
381
382
383
384
385
386
387
388
389
390
391
392
393
    
    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []
        for name in named_linears:
            handles.append(named_linears[name].register_forward_hook(
                functools.partial(cache_input_hook, name=name,
                                feat_dict=input_feat)))
Casper Hansen's avatar
Casper Hansen committed
394
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
Casper's avatar
Casper committed
395
        # get output as next layer's input
396
397
398
399
400
401
402
        
        # Sanitize the kwargs in case we use transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        self.inps = layer(self.inps, **module_kwargs)[0]
Casper's avatar
Casper committed
403
404
405
406
407
408
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
        
        return input_feat
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428


    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers. 

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in  inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs