import torch
import inspect
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)


class AwqQuantizer:
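    """
    Implements Activation-aware Weight Quantization (AWQ): per-channel
    scales are searched from calibration activations, applied to the
    weights, and the scaled weights are quantized to low bit-width and
    packed into WQLinear modules (GEMM/GEMV/Marlin).

    Typically constructed indirectly through a model wrapper's quantize
    entry point; a minimal sketch (argument values are illustrative only):

        quantizer = AwqQuantizer(
            awq_model, model, tokenizer, w_bit=4, group_size=128,
            zero_point=True, version="gemm", calib_data="pileval",
            split="train", text_column="text", duo_scaling=True,
        )
        quantizer.quantize()
    """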
    def __init__(
        self,
        awq_model,
        model,
        tokenizer,
        w_bit,
        group_size,
        zero_point,
        version,
        calib_data,
        split,
        text_column,
        duo_scaling,
        modules_to_not_convert=None,
        export_compatible=False,
    ) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
        self.export_compatible = export_compatible
        self.modules_to_not_convert = (
            modules_to_not_convert if modules_to_not_convert is not None else []
        )
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    def pseudo_quantize_tensor(self, w: torch.Tensor):
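        """
        Fake-quantize `w`: quantize per group of `group_size` input channels,
        then dequantize back to float so the quantization error is visible.
        Returns the dequantized weights, per-group scales, and zero points
        (None in symmetric mode).
        """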
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0

        # zero point quantization
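        # Asymmetric case: map [min_val, max_val] onto [0, 2^w_bit - 1] with
        # scale = (max_val - min_val) / (2^w_bit - 1) and an integer zero
        # point; the symmetric else-branch maps [-|w|_max, |w|_max] onto
        # [-2^(w_bit-1), 2^(w_bit-1) - 1] without a zero point.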
        if self.zero_point:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
            max_int = 2**self.w_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-5) / max_int
            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
            w = (
                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
            ) * scales
            zeros = zeros.view(org_w_shape[0], -1)
        else:
            max_val = w.abs().amax(dim=1, keepdim=True)
            max_val = max_val.clamp(min=1e-5)
            max_int = 2 ** (self.w_bit - 1) - 1
            min_int = -(2 ** (self.w_bit - 1))
            scales = max_val / max_int
            zeros = None
            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)

        return w, scales, zeros

    def pseudo_dequantize_tensor(
        self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
    ):
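        """
        Reconstruct float weights from a layer whose weight tensor holds
        integer quantized values, broadcasting each per-group scale (and
        zero point, in asymmetric mode) across its group of input channels.
        """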
        # get repeated count
        repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
        scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

        # dequantize
        if self.zero_point:
            zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
            w = (w.weight.data - zeros) * scales
        else:
            w = w.weight.data * scales

        return w

    def quantize(self):
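        """
        Run AWQ layer by layer: capture each layer's linear inputs, search
        and apply per-channel scales, search and apply weight clipping, then
        quantize the weights in place. With `export_compatible=True` only
        scaling and clipping are applied; call `pack()` afterwards.
        """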
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                if torch.cuda.is_available():
                    best_device = "cuda:" + str(i % torch.cuda.device_count())
                else:
                    best_device = get_best_device()

                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(
                self.modules[i], named_linears, input_feat
            )
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(
                clip_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 4]: Quantize weights
            if not self.export_compatible:
                self._apply_quant(self.modules[i], named_linears)

            clear_memory()

    def pack(self):
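        """
        Pack the already scaled and clipped linear layers into quantized
        modules; used after `quantize()` with `export_compatible=True`.
        """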
        for i in tqdm(range(len(self.modules)), desc="Packing"):
            named_linears = get_named_linears(self.modules[i])
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()

    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
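        """
        Pseudo-quantize each linear layer's weights and replace the module
        with a packed WQLinear built from the quantized values, scales, and
        zero points.
        """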
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.to(get_best_device()).half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data
            )

            if self.version == "gemm":
                scales = scales.t().contiguous()
                # guard: zeros is None in symmetric mode (zero_point=False)
                if zeros is not None:
                    zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM

            elif self.version == "gemv":
                q_linear_module = WQLinear_GEMV

            elif self.version == "marlin":
                q_linear_module = WQLinear_Marlin

            else:
                raise ValueError(f"Unknown version {self.version}")

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros,
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs={},
    ):
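        """
        For a group of linears sharing the same input, grid-search the
        per-input-channel scales that minimize the difference between the
        FP16 output of `module2inspect` and its output with quantized,
        scaled weights.
        """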
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute maximum of weight
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        weight = weight.view(-1, self.group_size)
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute maximum of x
        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

            fp16_output = module2inspect(inp, **module_kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
        )

        return (
            get_op_name(module, prev_op),
            tuple([get_op_name(module, m) for m in layers]),
            best_scales,
        )

    def _compute_best_scale(
        self,
        x,
        w_max,
        x_max,
        module2inspect,
        linears2scale: List[nn.Linear],
        fp16_output,
        kwargs={},
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_max = x_max.view(-1).to(device)
        w_max = w_max.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
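            # duo_scaling balances activation and weight statistics:
            # s = x_max^ratio / w_max^(1 - ratio); at ratio=0 only the
            # weight magnitudes matter, approaching activation-only
            # scaling as ratio -> 1.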
            if self.duo_scaling:
                scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
            else:
                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )

            # W * X
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error between fp16 and quantized outputs
            loss = (
                (fp16_output - int_w_output).float().pow(2).mean().item()
            )  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception("AWQ scale grid search failed to find a valid ratio")

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
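        """
        Search per-group clipping thresholds for each linear layer's
        weights; query/key projections are skipped since the q @ k matmul
        makes precise clipping difficult.
        """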
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].to(get_best_device())
            max_val = self._compute_best_clip(
                named_linears[name].weight, input_feat[name]
            )
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
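        """
        Grid-search a shrunken per-group absolute maximum: clip weights to
        org_max * (1 - i_s / n_grid), pseudo-quantize, and keep the
        threshold minimizing the MSE of the group output on sampled tokens.
        """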
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # guard against short calibration inputs (would give a slice step of 0)
        n_sample_token = min(n_sample_token, input_feat.shape[1])
        input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
        w = w.reshape(org_w_shape[0], 1, -1, group_size)

        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=128, seqlen=512):
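        """
        Tokenize the calibration data and capture the inputs (hidden states
        and kwargs) to the first decoder layer by temporarily wrapping it in
        a Catcher module whose forward raises to stop inference early.
        """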
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        best_device = get_best_device()
        modules[0] = modules[0].to(best_device)
        self.awq_model.move_embed(self.model, best_device)

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")

        del samples
        modules[0] = modules[0].module  # restore
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
                best_device
            )

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
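        """
        Attach forward hooks to every named linear in `layer` to cache its
        input activations, run the current `self.inps` through the layer
        (whose output becomes the next layer's input), and return the
        cached features keyed by linear name.
        """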
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []

        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
        if self.awq_model.model_type == "mixtral":
            named_linears = {
                **named_linears,
                "block_sparse_moe": layer.block_sparse_moe,
            }

        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input

        # Sanitize the kwargs in case we use transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        self.inps = layer(self.inps, **module_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs