import torch
import inspect
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)


class AwqQuantizer:
    def __init__(
        self,
        awq_model,
        model,
        tokenizer,
        w_bit,
        group_size,
        zero_point,
        version,
        calib_data,
        split,
        text_column,
        duo_scaling,
        modules_to_not_convert=None,
        export_compatible=False,
    ) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
        self.export_compatible = export_compatible
        self.modules_to_not_convert = (
            modules_to_not_convert if modules_to_not_convert is not None else []
        )
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    def pseudo_quantize_tensor(self, w: torch.Tensor):
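        """Fake-quantize `w`: round each group to `w_bit` integers, then
        dequantize back to floats. Returns the dequantized weights plus the
        per-group scales and (for zero-point quantization) zero points."""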
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0

        # zero_point=True: asymmetric quantization with a per-group zero point;
        # otherwise: symmetric absmax quantization
        if self.zero_point:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
            max_int = 2**self.w_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-5) / max_int
            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
            w = (
                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
            ) * scales
            zeros = zeros.view(org_w_shape[0], -1)
        else:
            max_val = w.abs().amax(dim=1, keepdim=True)
            max_val = max_val.clamp(min=1e-5)
            max_int = 2 ** (self.w_bit - 1) - 1
            min_int = -(2 ** (self.w_bit - 1))
            scales = max_val / max_int
            zeros = None
            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)

        return w, scales, zeros

    def pseudo_dequantize_tensor(
        self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
    ):
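        """Reconstruct float weights from an integer-quantized linear layer
        using its per-group scales (and zero points when zero_point=True)."""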
        # get repeated count
        repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
        scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

        # dequantize
        if self.zero_point:
            zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
            w = (w.weight.data - zeros) * scales
        else:
            w = w.weight.data * scales

        return w

    def quantize(self):
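        """Run the AWQ pipeline layer by layer: collect input features, search
        and apply scales, search and apply clipping, then quantize weights."""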
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                if torch.cuda.is_available():
                    best_device = "cuda:" + str(i % torch.cuda.device_count())
                else:
                    best_device = get_best_device()
                
                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            # Filter out linear layers that are configured not to be quantized
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(
                self.modules[i], named_linears, input_feat
            )
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(
                clip_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 4]: Quantize weights
            if not self.export_compatible:
                self._apply_quant(self.modules[i], named_linears)

            clear_memory()

    def pack(self):
        for i in tqdm(range(len(self.modules)), desc="Packing"):
            named_linears = get_named_linears(self.modules[i])
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()

    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
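        """Quantize each linear layer and swap it for the WQLinear module
        matching `self.version` (GEMM, GEMV, or Marlin)."""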
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.to(get_best_device()).half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data
            )

            if self.version == "GEMM":
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM

            elif self.version == "GEMV":
                q_linear_module = WQLinear_GEMV

            elif self.version == "Marlin":
                q_linear_module = WQLinear_Marlin

            else:
                raise ValueError(f"Unknown version {self.version}")

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros,
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs: Optional[Dict] = None,
    ):
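        """Grid-search a per-channel scale for one group of linear layers and
        return (prev_op_name, layer_names, best_scales)."""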
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        kwargs = kwargs if kwargs is not None else {}
        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute per-channel mean of normalized weight magnitudes
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # group weights; treat each output row as a single group when group_size <= 0
        group_size = self.group_size if self.group_size > 0 else org_shape[1]
        weight = weight.view(-1, group_size)
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute per-channel mean of absolute activations
        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

            fp16_output = module2inspect(inp, **module_kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
        )

        return (
            get_op_name(module, prev_op),
            tuple([get_op_name(module, m) for m in layers]),
            best_scales,
        )

    def _compute_best_scale(
        self,
        x,
        w_max,
        x_max,
        module2inspect,
        linears2scale: List[nn.Linear],
        fp16_output,
        kwargs: Optional[Dict] = None,
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        kwargs = kwargs if kwargs is not None else {}

        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_max = x_max.view(-1).to(device)
        w_max = w_max.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
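            # duo_scaling balances activation and weight statistics:
            #   s = x_max**ratio / w_max**(1 - ratio); otherwise s = x_max**ratio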
            if self.duo_scaling:
                scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
            else:
                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )

            # Q(W * s) @ (s^-1 * X): forward pass with fake-quantized weights
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error between FP16 and quantized outputs
            loss = (
                (fp16_output - int_w_output).float().pow(2).mean().item()
            )  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception("No valid scaling ratio found during grid search")

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
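        """Search the best clipping threshold for each eligible linear layer,
        skipping query/key projections, which are hard to clip precisely."""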
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any(_ in name for _ in avoid_clipping):
                continue

            named_linears[name].to(get_best_device())
            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
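        """Grid-search, per weight group, the clipping threshold that minimizes
        output MSE on a subsample of the calibration features."""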
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # guard against a zero stride when there are fewer tokens than n_sample_token
        n_sample_token = min(n_sample_token, input_feat.shape[1])
        input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
        w = w.reshape(org_w_shape[0], 1, -1, group_size)

        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=2, seqlen=512):
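        """Tokenize the calibration set and capture the inputs and kwargs that
        reach the first decoder layer via the early-exit Catcher below."""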
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        best_device = get_best_device()
        modules[0] = modules[0].to(best_device)
        self.awq_model.move_embed(self.model, best_device)

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")

        del samples
        modules[0] = modules[0].module  # restore
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(best_device)

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
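        """Capture the inputs of every named linear in `layer` with forward
        hooks while advancing `self.inps` to become the next layer's input."""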
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []

        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
        if self.awq_model.model_type == "mixtral":
            named_linears = {
                **named_linears,
                "block_sparse_moe": layer.block_sparse_moe,
            }

        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input

        # Sanitize the kwargs in case we use transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        self.inps = layer(self.inps, **module_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs
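

# A minimal usage sketch, assuming the standard AutoAWQ top-level API
# (`model_path` is a placeholder, not part of this module):
#
#   from awq import AutoAWQForCausalLM
#   from transformers import AutoTokenizer
#
#   model = AutoAWQForCausalLM.from_pretrained(model_path)
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#   quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
#   model.quantize(tokenizer, quant_config=quant_config)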