"tests/vscode:/vscode.git/clone" did not exist on "620e8924d9c6b2a0b1d49ac0dcf2588fffcbe390"
utils.py 17.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
12
13
14
15
16
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
17
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
18
19
20
21
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
22
23
24
25
from vllm.model_executor.layers.fused_moe.fused_moe import (
    TritonExperts,
    fused_experts,
)
26
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
27
28
29
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
30
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
31
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
32
from vllm.utils.deep_gemm import per_block_cast_to_fp8
33
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
34
35


36
37
38
39
40
41
42
43
44
45
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of the last dimension of ``w``.

    The last axis is split at ``n // 2``. Element ``i`` of the first half lands
    at position ``2 * i`` and element ``i`` of the second half at ``2 * i + 1``,
    so matching entries of the two halves become neighbours (the folded layout
    used for the Triton MoE / SwiGLU kernels). The result has the shape of ``w``.
    """
    half = w.shape[-1] // 2
    lower = w[..., :half]
    upper = w[..., half:]
    # Pair matching elements on a new trailing axis, then flatten that axis
    # back into the last dimension to obtain the interleaved order.
    return torch.stack([lower, upper], dim=-1).reshape(w.shape)


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """Build a placeholder ``FusedMoEConfig`` for the modular-kernel constructors.

    Most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN) do not actually use this
    config, so the size-1 defaults are enough for them. CUTLASSFp8 needs some
    of the parameters to be set for its work shapes.

    The config is non-parallel (``FusedMoEParallelConfig.make_no_parallel()``),
    uses SILU activation and TopK routing, targets the ``"cuda"`` device, and
    reports ``num_experts`` as both the local and the logical expert count.
    """
    return FusedMoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        num_local_experts=num_experts,
        num_logical_experts=num_experts,
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        activation=MoEActivation.SILU,
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )


bnellnm's avatar
bnellnm committed
75
76
77
78
79
80
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run ``fused_experts`` with a quant config assembled from the arguments.

    Args:
        a: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        topk_weight: Per-token routing weights of the selected experts.
        topk_ids: Per-token indices of the selected experts.
        w1_scale, w2_scale: Optional weight scales forwarded to the quant config.
        a1_scale, a2_scale: Optional activation scales forwarded likewise.
        quant_dtype: Quantized dtype, or ``None`` for no quantization.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional quantization block shape, forwarded likewise.

    Returns:
        Whatever ``fused_experts`` returns for these inputs.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
100
101
102
103
104
105
106
107


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run a modular MoE kernel built from the batched Triton experts.

    A ``FusedMoEModularKernel`` is assembled from ``BatchedPrepareAndFinalize``
    and ``BatchedTritonExperts`` (single dispatcher, rank 0, out-of-place) and
    applied to the inputs.

    Args:
        a: Input activations; ``a.shape[0]`` determines the token capacity.
        w1: First expert weights; ``w1.shape[0]`` is used as the local expert
            count.
        w2: Second expert weights.
        topk_weight: Per-token routing weights of the selected experts.
        topk_ids: Per-token indices of the selected experts.
        w1_scale, w2_scale: Optional weight scales for the quant config.
        a1_scale, a2_scale: Optional activation scales for the quant config.
        quant_dtype: Quantized dtype, or ``None`` for no quantization.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional quantization block shape, forwarded likewise.

    Returns:
        The output of the assembled modular kernel.
    """
    # Token capacity handed to both stages: a.shape[0] through round_up(..., 64).
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Named ``kernel`` so the module-level ``fused_experts`` import is not shadowed.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
        inplace=False,
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
142
143
144
145
146
147
148
149


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run a modular MoE kernel built from the naive batched experts.

    A ``FusedMoEModularKernel`` is assembled from ``BatchedPrepareAndFinalize``
    and ``NaiveBatchedExperts`` (single dispatcher, rank 0, out-of-place) and
    applied to the inputs.

    Args:
        a: Input activations; ``a.shape[0]`` determines the token capacity.
        w1: First expert weights; ``w1.shape[0]`` is used as the local expert
            count.
        w2: Second expert weights.
        topk_weight: Per-token routing weights of the selected experts.
        topk_ids: Per-token indices of the selected experts.
        w1_scale, w2_scale: Optional weight scales for the quant config.
        a1_scale, a2_scale: Optional activation scales for the quant config.
        quant_dtype: Quantized dtype, or ``None`` for no quantization.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Optional quantization block shape, forwarded likewise.

    Returns:
        The output of the assembled modular kernel.
    """
    # Token capacity handed to both stages: a.shape[0] through round_up(..., 64).
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Named ``kernel`` so the module-level ``fused_experts`` import is not shadowed.
    kernel = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        ),
        inplace=False,
    )

    return kernel(a, w1, w2, topk_weight, topk_ids)
bnellnm's avatar
bnellnm committed
184
185


186
def chunk_scales(
187
188
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
189
190
191
192
193
194
195
196
197
198
199
200
201
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Create random per-expert activations and, optionally, quantize them.

    Args:
        E: Size of the leading (expert) dimension.
        m: Second dimension of the activations.
        k: Last dimension of the activations.
        in_dtype: Dtype of the unquantized activations.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8``, or ``None`` to
            skip quantization.
        block_shape: Optional block shape passed to the quantization helper.
        per_act_token_quant: Passed to the quantization helper.

    Returns:
        ``(a, a_q, a_scale)``: the ``(E, m, k)`` random activations on "cuda",
        their quantized counterpart, and the stacked per-expert scales. Without
        quantization ``a_q`` is ``a`` itself and ``a_scale`` is ``None``.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    if quant_dtype is None:
        return a, a, None

    assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
        "only fp8/int8 supported"
    )

    # Each expert slice is quantized independently; results are written into a
    # preallocated buffer and the scales are collected per expert.
    a_q = torch.zeros_like(a, dtype=quant_dtype)
    a_scale_l = [None] * E
    for e in range(E):
        a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
            a[e], None, quant_dtype, per_act_token_quant, block_shape
        )
    a_scale = torch.stack(a_scale_l)

    # With neither per-token nor block quantization the scales are reshaped
    # to (E, 1, 1).
    if not per_act_token_quant and block_shape is None:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Quantize one weight tensor to fp8, int8 or nvfp4.

    Args:
        w: Weight tensor to quantize.
        w_s: Optional scale handed to the non-blocked int8/fp8 quant ops.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8`` or ``"nvfp4"``.
        per_token_quant: Passed as ``use_per_token_if_dynamic`` to the
            non-blocked int8/fp8 ops; must be false for block and nvfp4
            quantization.
        block_shape: When given, block-wise quantization is used (int8/fp8 only).

    Returns:
        ``(w, w_s, w_gs)``: the quantized weight, its scale, and the global
        scale, which is only set on the nvfp4 path and ``None`` otherwise.

    Raises:
        RuntimeError: If block quantization is requested for nvfp4, or the
            quant type is not handled by the selected branch.
    """
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"

    w_gs = None

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            # Global scale derived from the largest weight magnitude.
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
271
272
273
274
275
276
277


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Create a random ``(e, rows, cols)`` weight and optionally quantize it.

    Args:
        e: Number of experts (leading dimension).
        rows: Rows of each expert weight.
        cols: Columns of each expert weight.
        in_dtype: Dtype of the unquantized weight.
        quant_dtype: Quantization type understood by ``moe_quantize_weights``,
            or ``None`` to skip quantization.
        block_shape: Optional ``[block_n, block_k]`` for block quantization.
        per_out_ch_quant: Passed to ``moe_quantize_weights`` as its
            ``per_token_quant`` argument.

    Returns:
        ``(w_16, w, w_s, w_gs)``: the unquantized weight on "cuda", the weight
        to use (quantized, or ``w_16`` itself without quantization), the
        stacked scales (``None`` without quantization) and the stacked global
        scales (``None`` unless the quantizer produced them).
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
    w_gs = None

    if quant_dtype is not None:
        w_l = [None] * e
        w_s_l = [None] * e
        w_gs_l = [None] * e
        # Every expert is quantized on its own, then the pieces are restacked.
        for idx in range(e):
            w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
                w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
            )

        w = torch.stack(w_l)
        w_s = torch.stack(w_s_l)
        if e > 0 and w_gs_l[0] is not None:
            w_gs = torch.stack(w_gs_l)

        # Stacked 2-D scales of shape (e, 1) are reshaped to (e, 1, 1).
        if w_s.ndim == 2:
            assert w_s.shape[-1] == 1
            w_s = w_s.view(-1, 1, 1)

        if block_shape is not None:
            # Sanity check: one scale per (block_n x block_k) tile.
            block_n, block_k = block_shape
            n_tiles = (rows + block_n - 1) // block_n
            k_tiles = (cols + block_k - 1) // block_k
            assert w_s.shape == (e, n_tiles, k_tiles)
    else:
        w = w_16
        w_s = None

    return w_16, w, w_s, w_gs
bnellnm's avatar
bnellnm committed
313
314
315
316
317
318
319


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """Create the pair of expert weights used by the MoE tests.

    The first weight has shape ``(e, 2 * n, k)`` when ``make_gate`` is true and
    ``(e, n, k)`` otherwise; the second has shape ``(e, k, n)``. Both are built
    by ``make_test_weight`` with the same dtype and quantization settings.

    Returns:
        Two ``(w_16, w, w_s, w_gs)`` tuples as returned by ``make_test_weight``,
        first-weight tuple first.
    """
    w1_rows = (2 if make_gate else 1) * n
    return (
        make_test_weight(
            e,
            w1_rows,
            k,
            in_dtype,
            quant_dtype,
            block_shape,
            per_out_ch_quant,
        ),
        make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
    )
340
341
342


def per_token_cast_to_fp8(
343
344
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
345
346
347
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
348
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
349
350
351
352
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
353
354


355
356
357
358
359
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """Create test weights together with a matching ``FusedMoEQuantConfig``.

    Args:
        e: Number of experts.
        n: Intermediate size passed to ``make_test_weights``.
        k: Hidden size passed to ``make_test_weights``.
        in_dtype: Dtype of the unquantized weights.
        quant_dtype: Quantization type, or ``None`` for no quantization.
        per_act_token_quant: Used for the quant config and also passed as
            ``per_out_ch_quant`` when the weights are created.
        block_shape: Optional quantization block shape.
        make_gate: Forwarded to ``make_test_weights``.

    Returns:
        ``(w1, w2, quant_config)`` with the (possibly quantized) weights and a
        quant config carrying their scales.
    """
    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )

    # Hacky/trivial scales for nvfp4.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a1_scale = a1_gscale
        a2_scale = a2_gscale
    else:
        a1_scale = None
        a2_scale = None

    return (
        w1,
        w2,
        FusedMoEQuantConfig.make(
            quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_gscale=a1_gscale,
            a2_gscale=a2_gscale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            # TODO: make sure this is handled properly
            # The alphas are the reciprocals of the weight global scales.
            g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
            g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
        ),
    )


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """Route tokens with ``fused_topk`` and run ``fused_experts`` on the result.

    Args:
        hidden_states: Input activations.
        w1: First expert weight tensor.
        w2: Second expert weight tensor.
        score: Router scores; converted to float before top-k selection.
        topk: Number of experts selected per token.
        renormalize: Forwarded to ``fused_topk``.
        quant_config: Optional quantization config for ``fused_experts``.
        global_num_experts: Forwarded to ``fused_experts``.
        expert_map: Optional expert map forwarded to ``fused_experts``.

    Returns:
        The output of ``fused_experts``.
    """
    # The third value returned by fused_topk is not needed here.
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, score.float(), topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
432
433


434
435
436
437
438
439
440
441
442
443
444
# CustomOp?
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

445
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
446
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488


class TestMLP(torch.nn.Module):
    """Naive gated MLP assembled from two ``BaselineMM`` layers.

    The input goes through the gate/up projection (``w1``), the ``SiluAndMul``
    activation and finally the down projection (``w2``).
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        """Wrap ``w1`` and ``w2`` in baseline matmul layers with ``out_dtype``."""
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Each projection returns (output, extra); only the output is used.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        result, _ = self.down_proj(activated)
        return result


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """Build a ``TestMLP`` with random weights for use as shared experts.

    The gate/up weight has shape ``(K, 2 * N)`` and the down weight ``(N, K)``;
    both are drawn on "cuda" in ``in_dtype`` and scaled by ``1 / 15``. The MLP
    produces outputs in ``in_dtype``.
    """
    # Shapes listed in creation order (gate/up first) so the random draws
    # happen in the same sequence.
    shapes = ((K, N * 2), (N, K))
    gate_up_w, down_w = (
        torch.randn(shape, device="cuda", dtype=in_dtype) / 15 for shape in shapes
    )
    return TestMLP(gate_up_w, down_w, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """Gated MLP built from vLLM's parallel linear layers with preset weights.

    The gate/up and down projections are created through
    ``MergedColumnParallelLinear`` and ``RowParallelLinear``; their ``weight``,
    ``weight_scale`` and ``input_scale`` parameters are then replaced with the
    tensors supplied by the caller.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        """Create the projections and install the provided weights and scales.

        Args:
            hidden_size: Input size of the gate/up projection and output size
                of the down projection.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of the down projection.
            w1: Weight installed on the gate/up projection.
            w2: Weight installed on the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config passed to both linear layers.
            reduce_results: Passed to the down projection.
            prefix: Name prefix for the two layers.
            w1_s: Weight scale for the gate/up projection.
            w2_s: Weight scale for the down projection.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.gate_up_proj.register_parameter(
            "weight", torch.nn.Parameter(w1, requires_grad=False)
        )
        # NOTE(review): when w1_s is None this wraps None in a Parameter rather
        # than registering None -- confirm that is the intended behaviour.
        self.gate_up_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
        )
        self.gate_up_proj.register_parameter("input_scale", None)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.down_proj.register_parameter(
            "weight", torch.nn.Parameter(w2, requires_grad=False)
        )
        # NOTE(review): same remark as for the gate/up weight_scale above.
        self.down_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
        )
        self.down_proj.register_parameter("input_scale", None)
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, ``SiluAndMul`` and down projection."""
        # Each projection returns a pair; only its first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a ``RealMLP`` with random (optionally fp8-quantized) weights.

    Args:
        N: Intermediate size of the MLP.
        K: Hidden size of the MLP.
        in_dtype: Dtype of the unquantized weights; also used as the default
            dtype while the MLP is constructed.
        quant_dtype: When ``torch.float8_e4m3fn``, quantized weights and their
            scales are used together with an ``Fp8Config``; for any other
            value the scales are dropped and no quant config is used.

    Returns:
        The constructed ``RealMLP``.
    """
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    # Weights are generated for a single expert; index 0 is extracted below.
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    old_dtype = torch.get_default_dtype()
    try:
        # Temporarily switch the default dtype; it is restored in ``finally``.
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            quant_config = Fp8Config(True)
        else:
            w1 = w1[0]
            w2 = w2[0]
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(old_dtype)
578
579
580
581
582
583
584
585
586
587
588
589
590


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel:
    """Assemble an out-of-place modular kernel around ``TritonExperts``.

    Args:
        moe_config: MoE configuration handed to ``TritonExperts``.
        quant_config: Quantization configuration handed to ``TritonExperts``.
        shared_experts: Optional shared-experts module passed to the kernel.

    Returns:
        A ``FusedMoEModularKernel`` combining ``MoEPrepareAndFinalizeNoEP`` with
        ``TritonExperts``, created with ``inplace=False``.
    """
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = TritonExperts(moe_config, quant_config)
    return FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
        inplace=False,
    )