import torch
import inspect
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_Marlin,
    WQLinear_GEMVFast,
)
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)


class AwqQuantizer:
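    """
    Runs activation-aware weight quantization (AWQ) over a model: it collects
    calibration inputs block by block, searches per-channel scales and
    (optionally) clipping thresholds that minimize each block's output error,
    and then quantizes the linear layers to `w_bit` integers.
    """
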
    def __init__(
        self,
        awq_model,
        model,
        tokenizer,
        w_bit,
        group_size,
        zero_point,
        version,
        calib_data,
        split,
        text_column,
        duo_scaling,
        modules_to_not_convert=None,
        export_compatible=False,
        apply_clip=True,
    ) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
        self.export_compatible = export_compatible
        self.apply_clip = apply_clip
        self.modules_to_not_convert = (
            modules_to_not_convert if modules_to_not_convert is not None else []
        )
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    def pseudo_quantize_tensor(self, w: torch.Tensor):
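        """
        Fake-quantize a weight tensor group-wise to `w_bit` bits and return the
        dequantized weights together with the per-group scales and, when
        zero-point quantization is enabled, the integer zero points.
        """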
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0

        # zero point quantization
        if self.zero_point:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
            max_int = 2**self.w_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-5) / max_int
            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
            w = (
                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
            ) * scales
            zeros = zeros.view(org_w_shape[0], -1)
        else:
            max_val = w.abs().amax(dim=1, keepdim=True)
            max_val = max_val.clamp(min=1e-5)
            max_int = 2 ** (self.w_bit - 1) - 1
            min_int = -(2 ** (self.w_bit - 1))
            scales = max_val / max_int
            zeros = None
            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)

        return w, scales, zeros

    def pseudo_dequantize_tensor(
        self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
    ):
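        """
        Reconstruct floating point weights for a quantized linear layer from
        its per-group scales (and zero points when zero-point quantization is
        enabled).
        """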
        # get repeated count
        repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
        scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

        # dequantize
        if self.zero_point:
            zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
            w = (w.weight.data - zeros) * scales
        else:
            w = w.weight.data * scales

        return w

    def quantize(self):
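        """
        Run the AWQ pipeline over every block: gather input features, search
        and apply the best scales, optionally search and apply clipping, and
        finally quantize the block's linear layers (skipped when
        `export_compatible` is set; see `pack`).
        """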
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                if torch.cuda.is_available():
                    best_device = "cuda:" + str(i % torch.cuda.device_count())
                else:
                    best_device = get_best_device()

                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: Compute and apply clipping list
            if self.apply_clip:
                clip_list = self._search_best_clip(
                    self.modules[i], named_linears, input_feat
                )
                apply_clip(self.modules[i], clip_list)
                clip_list = append_str_prefix(
                    clip_list, get_op_name(self.model, self.modules[i]) + "."
                )

            # [STEP 4]: Quantize weights
            if not self.export_compatible:
                self._apply_quant(self.modules[i], named_linears)

            clear_memory()

    def pack(self):
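        """
        Quantize and pack the linear layers of every block. Used after
        `quantize` when `export_compatible=True`, in which case `quantize`
        only applies scales and clipping.
        """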
        for i in tqdm(range(len(self.modules)), desc="Packing"):
            named_linears = get_named_linears(self.modules[i])
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()

    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
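        """
        Fake-quantize each linear layer's weights and replace the layer with
        the packed low-bit module selected by `self.version`.
        """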
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.to(get_best_device()).half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data
            )

            if self.version == "gemm":
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM

            elif self.version == "gemv":
                q_linear_module = WQLinear_GEMV

            elif self.version == "marlin":
                q_linear_module = WQLinear_Marlin
            
            elif self.version == "gemv_fast":
                q_linear_module = WQLinear_GEMVFast

            else:
                raise ValueError(f"Unknown version {self.version}")

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros,
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs={},
    ):
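        """
        Grid-search per-channel scales for one group of linear layers by
        comparing the FP16 output of `module2inspect` with its output after
        the scaled weights are fake-quantized. Returns a tuple of
        (previous op name, tuple of layer names, best scales).
        """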
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute per-channel mean of normalised weights
        # All layer weights are concatenated together
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # The weights are reshaped to be organised by quantization group
        weight = weight.view(-1, self.group_size)
        # Calculates the relative magnitude of the weights within each of the quantization groups, 
        # and rescales each group individually so that each group has weights on a 0-1 scale.
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        # Resizes the rescaled weight matrix back up to its original dimensions
        w_scale = w_scale.view(org_shape)
        # Gets the average rescaled magnitude for each input channel
        # (averaged over the concatenated output rows)
        w_mean = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute per-channel mean of the input activation
        x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

            fp16_output = module2inspect(inp, **module_kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
        )

        return (
            get_op_name(module, prev_op),
            tuple([get_op_name(module, m) for m in layers]),
            best_scales,
        )

    def _compute_best_scale(
        self,
        x,
        w_mean,
        x_mean,
        module2inspect,
        linears2scale: List[nn.Linear],
        fp16_output,
        kwargs={},
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_mean = x_mean.view(-1).to(device)
        w_mean = w_mean.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )

            # Q(W * s) * (s^-1 * X): output with the fake-quantized, rescaled weights
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error between the FP16 and quantized outputs
            loss = (
                (fp16_output - int_w_output).float().pow(2).mean().item()
            )  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception("No valid scale ratio found during AWQ grid search")

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
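        """
        Search a clipping threshold for each linear layer (query/key
        projections are skipped) and return a list of (name, max_val) pairs.
        """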
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].to(get_best_device())
            max_val = self._compute_best_clip(
                named_linears[name].weight, input_feat[name]
            )
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
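        """
        Shrink the per-group maximum absolute weight on a grid and keep the
        clipping threshold that minimizes the error of the quantized output
        against the unclipped output on a sample of calibration tokens.
        """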
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # subsample tokens with a fixed stride to limit memory usage
        n_sample_token = min(n_sample_token, input_feat.shape[1])
        input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
        w = w.reshape(org_w_shape[0], 1, -1, group_size)

        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=128, seqlen=512):
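        """
        Tokenize the calibration dataset and run it through the embeddings and
        the first block, capturing the hidden states and forward kwargs that
        the quantization loop replays layer by layer.
        """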
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        best_device = get_best_device()
        modules[0] = modules[0].to(best_device)
        self.awq_model.move_embed(self.model, best_device)

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        modules[0] = modules[0].module  # restore

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")

        del samples
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
                best_device
            )

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
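        """
        Run the cached inputs through `layer` with forward hooks on each linear
        module to record its input activations, advancing `self.inps` to the
        layer's output for the next block.
        """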
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []

        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
        if self.awq_model.model_type == "mixtral":
            named_linears = {
                **named_linears,
                "block_sparse_moe": layer.block_sparse_moe,
            }

        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input

        # Sanitize the kwargs in case the installed transformers version passes
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        self.inps = layer(self.inps, **module_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs