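"""AWQ (Activation-aware Weight Quantization) quantizer.

Searches per-channel activation-aware scales and per-group weight clipping
thresholds for each transformer block, then packs the resulting low-bit
weights into the WQLinear kernel modules (GEMM, GEMV, Marlin, GEMVFast).
"""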
import torch
import inspect
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_Marlin,
    WQLinear_GEMVFast,
)
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
    get_named_linears,
    set_op_by_name,
    exclude_layers_to_not_quantize,
)


class AwqQuantizer:
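    """Applies AWQ to a model block by block.

    For every transformer block, calibration inputs are captured, per-channel
    scales and per-group clipping thresholds are grid-searched to minimize the
    output error of the pseudo-quantized block, and each ``nn.Linear`` is then
    replaced by the WQLinear module selected by ``version``.
    """
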
    def __init__(
        self,
        awq_model,
        model,
        tokenizer,
        w_bit,
        group_size,
        zero_point,
        version,
        calib_data,
        split,
        text_column,
        duo_scaling,
        modules_to_not_convert=None,
        export_compatible=False,
    ) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
        self.export_compatible = export_compatible
        self.modules_to_not_convert = (
            modules_to_not_convert if modules_to_not_convert is not None else []
        )
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    def pseudo_quantize_tensor(self, w: torch.Tensor):
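        """Fake-quantize ``w`` to ``self.w_bit`` bits in groups of ``self.group_size``.

        Uses asymmetric (zero-point) quantization when ``self.zero_point`` is set,
        otherwise symmetric quantization. Returns the dequantized weights together
        with the per-group scales and zero points (``zeros`` is ``None`` in the
        symmetric case).
        """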
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0

        # zero point quantization
        if self.zero_point:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
            max_int = 2**self.w_bit - 1
            min_int = 0
            scales = (max_val - min_val).clamp(min=1e-5) / max_int
            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
            w = (
                torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
            ) * scales
            zeros = zeros.view(org_w_shape[0], -1)
        else:
            max_val = w.abs().amax(dim=1, keepdim=True)
            max_val = max_val.clamp(min=1e-5)
            max_int = 2 ** (self.w_bit - 1) - 1
            min_int = -(2 ** (self.w_bit - 1))
            scales = max_val / max_int
            zeros = None
            w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)

        return w, scales, zeros

    def pseudo_dequantize_tensor(
        self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
    ):
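        """Reconstruct floating-point weights from a quantized linear layer by
        broadcasting the per-group ``scales`` (and ``zeros`` for zero-point
        quantization) back over the weight matrix."""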
        # get repeated count
        repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
        scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

        # dequantize
        if self.zero_point:
            zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
            w = (w.weight.data - zeros) * scales
        else:
            w = w.weight.data * scales

        return w

    def quantize(self):
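        """Run AWQ block by block: capture input features, search and apply the
        best scales, search and apply clipping, then quantize the weights in
        place (unless ``export_compatible`` is set, in which case quantization
        is deferred to ``pack()``)."""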
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                if torch.cuda.is_available():
                    best_device = "cuda:" + str(i % torch.cuda.device_count())
                else:
                    best_device = get_best_device()

                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)

            self.inps = self.inps.to(common_device)

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to quantize
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )

            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(
                self.modules[i], named_linears, input_feat
            )
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(
                clip_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 4]: Quantize weights
            if not self.export_compatible:
                self._apply_quant(self.modules[i], named_linears)

            clear_memory()

    def pack(self):
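        """Quantize and pack all blocks; intended for use after ``quantize()``
        was run with ``export_compatible=True``, where scales and clipping were
        applied but the weights were left unquantized."""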
        for i in tqdm(range(len(self.modules)), desc="Packing"):
            named_linears = get_named_linears(self.modules[i])
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()

    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
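        """Pseudo-quantize each linear layer's weights and swap the layer in
        ``module`` for the packed WQLinear implementation selected by
        ``self.version``."""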
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.to(get_best_device()).half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data
            )

            if self.version == "gemm":
                scales = scales.t().contiguous()
                # zeros is None for symmetric quantization; guard before transposing
                if zeros is not None:
                    zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM

            elif self.version == "gemv":
                q_linear_module = WQLinear_GEMV

            elif self.version == "marlin":
                q_linear_module = WQLinear_Marlin
            
            elif self.version == "gemv_fast":
                q_linear_module = WQLinear_GEMVFast

            else:
                raise ValueError(f"Unknown version {self.version}")

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros,
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs={},
    ):
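        """Search per-channel scales for one group of linears fed by ``prev_op``.

        Combines a group-normalized weight magnitude statistic (``w_max``) with
        the mean absolute activation per channel (``x_max``) and grid-searches
        the scales that minimize the output error of ``module2inspect`` after
        pseudo-quantization.
        """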
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute maximum of weight
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        weight = weight.view(-1, self.group_size)
        # normalize each group to [0, 1]; epsilon guards against all-zero groups
        w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute maximum of x
        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

            fp16_output = module2inspect(inp, **module_kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
        )

        return (
            get_op_name(module, prev_op),
            tuple([get_op_name(module, m) for m in layers]),
            best_scales,
        )

    def _compute_best_scale(
        self,
        x,
        w_max,
        x_max,
        module2inspect,
        linears2scale: List[nn.Linear],
        fp16_output,
        kwargs={},
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_max = x_max.view(-1).to(device)
        w_max = w_max.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
            else:
                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )

            # W * X
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error between FP16 and pseudo-quantized outputs
            loss = (
                (fp16_output - int_w_output).float().pow(2).mean().item()
            )  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
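        """Search per-group clipping thresholds for each linear layer's weights,
        skipping query/key projections (and fused Wqkv), and return the results
        as ``(name, max_val)`` pairs."""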
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].to(get_best_device())
            max_val = self._compute_best_clip(
                named_linears[name].weight, input_feat[name]
            )
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
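        """Grid-search, per output channel and weight group, the absolute
        clipping value that minimizes the error between the original output and
        the clipped, pseudo-quantized output on the sampled input features."""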
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # clamp so the slice step never becomes zero when there are fewer tokens
        n_sample_token = min(n_sample_token, input_feat.shape[1])
        input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
        w = w.reshape(org_w_shape[0], 1, -1, group_size)

        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=128, seqlen=512):
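        """Tokenize the calibration data and run it through the embeddings to
        capture the inputs (hidden states and kwargs) of the first transformer
        block. Returns ``(modules, layer_kwargs, inps)``."""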
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        best_device = get_best_device()
        modules[0] = modules[0].to(best_device)
        self.awq_model.move_embed(self.model, best_device)

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        modules[0] = modules[0].module  # restore

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")

        del samples
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
                best_device
            )

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
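        """Run the cached block inputs through ``layer`` with forward hooks on
        every linear to record each linear's input activations, and advance
        ``self.inps`` to this block's output for the next block."""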
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []

        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
        if self.awq_model.model_type == "mixtral":
            named_linears = {
                **named_linears,
                "block_sparse_moe": layer.block_sparse_moe,
            }

        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input

        # Sanitize the kwargs in case we use transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        self.inps = layer(self.inps, **module_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs
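
# Illustrative usage sketch (assumptions, not part of the library API):
# `awq_model` is assumed to be an AutoAWQ model wrapper exposing
# get_model_layers / get_layers_for_scaling / move_embed, `model` its
# underlying transformers model, and `tokenizer` a matching tokenizer;
# the calibration dataset / split / column values are placeholders.
#
#   quantizer = AwqQuantizer(
#       awq_model, model, tokenizer,
#       w_bit=4, group_size=128, zero_point=True, version="gemm",
#       calib_data="pileval", split="train", text_column="text",
#       duo_scaling=True,
#   )
#   quantizer.quantize()  # searches scales/clips and swaps in WQLinear modules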