import torch
import inspect
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_Marlin,
    WQLinear_GEMVFast,
)
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)


class AwqQuantizer:
    def __init__(
        self,
        awq_model,
        model,
        tokenizer,
        w_bit,
        group_size,
        zero_point,
        version,
        calib_data,
        split,
        text_column,
        duo_scaling,
        modules_to_not_convert=None,
        export_compatible=False,
    ) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
        self.export_compatible = export_compatible
        self.modules_to_not_convert = (
            modules_to_not_convert if modules_to_not_convert is not None else []
        )
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    def pseudo_quantize_tensor(self, w: torch.Tensor):
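        """
        Group-wise simulated quantization: quantize `w` and immediately
        dequantize it, returning the reconstructed weights together with the
        per-group scales and (in zero-point mode) the integer zero points.
        """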
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0

        # asymmetric (zero-point) quantization or symmetric (absmax) quantization
        if self.zero_point:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
            max_int = 2**self.w_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-5) / max_int
            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
            w = (
                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
            ) * scales
            zeros = zeros.view(org_w_shape[0], -1)
        else:
            max_val = w.abs().amax(dim=1, keepdim=True)
            max_val = max_val.clamp(min=1e-5)
            max_int = 2 ** (self.w_bit - 1) - 1
            min_int = -(2 ** (self.w_bit - 1))
            scales = max_val / max_int
            zeros = None
            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)

        return w, scales, zeros

    def pseudo_dequantize_tensor(
        self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
    ):
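        """
        Expand the per-group scales (and zero points, when present) back to the
        full weight shape and return the dequantized weights of `w`.
        """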
        # get repeated count
        repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
        scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

        # dequantize
        if self.zero_point:
            zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
            w = (w.weight.data - zeros) * scales
        else:
            w = w.weight.data * scales

        return w

    def quantize(self):
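        """
        Run AWQ over every decoder block: move the block to a device, collect
        input features, search and apply activation-aware scales, search and
        apply weight clipping, then (unless export_compatible) quantize and
        pack the block's linear layers in place.
        """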
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                if torch.cuda.is_available():
                    best_device = "cuda:" + str(i % torch.cuda.device_count())
                else:
                    best_device = get_best_device()

                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(
                self.modules[i], named_linears, input_feat
            )
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(
                clip_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 4]: Quantize weights
            if not self.export_compatible:
                self._apply_quant(self.modules[i], named_linears)

            clear_memory()

    def pack(self):
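        """
        Quantize and pack all linear layers that were left in FP16, typically
        after running quantize() with export_compatible=True.
        """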
        for i in tqdm(range(len(self.modules)), desc="Packing"):
            named_linears = get_named_linears(self.modules[i])
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()

    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
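        """
        Pseudo-quantize each linear layer and replace it in-place with the
        packed WQLinear module that matches the selected kernel version.
        """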
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.to(get_best_device()).half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data
            )

            if self.version == "gemm":
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM

            elif self.version == "gemv":
                q_linear_module = WQLinear_GEMV

            elif self.version == "marlin":
                q_linear_module = WQLinear_Marlin

            elif self.version == "gemv_fast":
                q_linear_module = WQLinear_GEMVFast

            else:
                raise ValueError(f"Unknown version {self.version}")

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros,
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs={},
    ):
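        """
        Search the per-channel scaling factors for one group of linear layers,
        minimising the output error of `module2inspect` once the layers are
        quantized (the AWQ objective L(s) = ||Q(W * s)(s^-1 * X) - W X||).
        """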
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute per-channel mean of normalised weights
        # All layer weights are concatenated together
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # The weights are reshaped to be organised by quantization group
        weight = weight.view(-1, self.group_size)
        # Calculates the relative magnitude of the weights within each of the quantization groups,
        # and rescales each group individually so that each group has weights on a 0-1 scale.
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        # Resizes the rescaled weight matrix back up to its original dimensions
        w_scale = w_scale.view(org_shape)
        # Gets the average rescaled magnitude for each output channel
        w_mean = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute per-channel mean of the input activation
        x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

            fp16_output = module2inspect(inp, **module_kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
        )

        return (
            get_op_name(module, prev_op),
            tuple([get_op_name(module, m) for m in layers]),
            best_scales,
        )

    def _compute_best_scale(
        self,
        x,
        w_mean,
        x_mean,
        module2inspect,
        linears2scale: List[nn.Linear],
        fp16_output,
        kwargs={},
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_mean = x_mean.view(-1).to(device)
        w_mean = w_mean.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )

            # Q(W * s) (s^-1 * X): forward pass with the scaled-and-quantized weights
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error between the FP16 and quantized outputs
            loss = (
                (fp16_output - int_w_output).float().pow(2).mean().item()
            )  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception("AWQ grid search failed to find a valid scale ratio")

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
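        """
        For each linear layer (skipping query/key projections), grid-search the
        best absolute clipping threshold for its weights using the cached input
        features.
        """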
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].to(get_best_device())
            max_val = self._compute_best_clip(
                named_linears[name].weight, input_feat[name]
            )
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
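        """
        Grid-search, per quantization group and in output-channel batches, the
        clipping value on |w| that minimises the error between the original and
        the clipped-then-quantized group outputs on the sampled input features.
        """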
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
        w = w.reshape(org_w_shape[0], 1, -1, group_size)

        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=128, seqlen=512):
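        """
        Prepare calibration: tokenize the calibration dataset, run it through
        the model until the first decoder block (patched with a Catcher module)
        captures its inputs, and return the decoder blocks, the kwargs seen by
        block 0, and the captured hidden states.
        """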
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        best_device = get_best_device()
        modules[0] = modules[0].to(best_device)
        self.awq_model.move_embed(self.model, best_device)

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        modules[0] = modules[0].module  # restore

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")

        del samples
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
                best_device
            )

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
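        """
        Cache the inputs of every (non-excluded) linear layer in this block via
        forward hooks, run the block once, and return the concatenated features
        keyed by layer name; self.inps is advanced to the block's output.
        """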
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []

        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
        if self.awq_model.model_type == "mixtral":
            named_linears = {
                **named_linears,
                "block_sparse_moe": layer.block_sparse_moe,
            }

        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input

        # Sanitize the kwargs in case we use transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        self.inps = layer(self.inps, **module_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs