utils.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
12
13
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
14
15
16
17
18
19
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
21
22
23
24
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
25
26
27
28
from vllm.model_executor.layers.fused_moe.fused_moe import (
    TritonExperts,
    fused_experts,
)
29
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
30
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
31
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
32
from vllm.utils.deep_gemm import per_block_cast_to_fp8
33
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
34
35


36
37
38
39
40
41
42
43
44
45
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of the last dimension of ``w``.

    Element ``i`` of the first half and element ``i`` of the second half end
    up adjacent (``[a0, b0, a1, b1, ...]``), which is the layout used by the
    Triton MoE / SwiGLU kernel. The result has the same shape as ``w``.
    """
    original_shape = w.shape
    half = original_shape[-1] // 2
    lo, hi = w[..., :half], w[..., half:]
    return torch.stack([lo, hi], dim=-1).reshape(original_shape)


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """Build a minimal single-device ``FusedMoEConfig`` for tests.

    This is a dummy config for the modular-kernel constructor interface, as
    most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN) do not actually use
    it. CUTLASSFp8 needs some of these params set for its workspace shapes,
    so pass realistic values when testing that kernel.

    ``num_experts`` is used for the global, local and logical expert counts
    alike. The config has no parallelism, uses the SiLU activation and top-k
    routing, and targets the ``"cuda"`` device.
    """
    return FusedMoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        num_local_experts=num_experts,
        num_logical_experts=num_experts,
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        activation=MoEActivation.SILU,
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )


bnellnm's avatar
bnellnm committed
75
76
77
78
79
80
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run the non-modular Triton ``fused_experts`` path on pre-routed tokens.

    Args:
        a: Input activations.
        w1: Expert weights for the first GEMM.
        w2: Expert weights for the second GEMM.
        topk_weight: Precomputed routing weights.
        topk_ids: Precomputed expert ids per token.
        w1_scale: Optional scale for ``w1``.
        w2_scale: Optional scale for ``w2``.
        a1_scale: Optional scale for the first GEMM's activations.
        a2_scale: Optional scale for the second GEMM's activations.
        quant_dtype: Quantization dtype, or None for unquantized execution.
        per_act_token_quant: Use per-token activation scales.
        block_shape: Block-quantization shape, or None.

    Returns:
        The output of ``fused_experts``.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
100
101
102
103
104
105
106
107


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run MoE through the batched modular kernel with Triton experts.

    Builds a ``FusedMoEKernel`` from ``BatchedPrepareAndFinalize`` and
    ``BatchedTritonExperts`` for a single dispatcher (rank 0), then applies it
    with the SiLU activation and no expert map.

    Args:
        a: Input activations.
        w1: Expert weights for the first GEMM; ``w1.shape[0]`` is the expert count.
        w2: Expert weights for the second GEMM.
        topk_weight: Precomputed routing weights.
        topk_ids: Precomputed expert ids per token.
        w1_scale: Optional scale for ``w1``.
        w2_scale: Optional scale for ``w2``.
        a1_scale: Optional scale for the first GEMM's activations.
        a2_scale: Optional scale for the second GEMM's activations.
        quant_dtype: Quantization dtype, or None for unquantized execution.
        per_act_token_quant: Use per-token activation scales.
        block_shape: Block-quantization shape, or None.

    Returns:
        The MoE output for ``a``.
    """
    # Per-expert batch capacity: the token count rounded up to a multiple of 64.
    max_num_tokens = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    moe_config = make_dummy_moe_config()

    # Named ``kernel`` so it does not shadow the module-level ``fused_experts``.
    kernel = FusedMoEKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=num_experts, rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=moe_config,
        ),
        inplace=False,
    )

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=num_experts,
        activation=moe_config.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
154
155
156
157
158
159
160
161


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run MoE through the batched modular kernel with naive experts.

    Same as the Triton batched path but uses ``NaiveBatchedExperts``.
    Builds a ``FusedMoEKernel`` from ``BatchedPrepareAndFinalize`` for a
    single dispatcher (rank 0), then applies it with the SiLU activation and
    no expert map.

    Args:
        a: Input activations.
        w1: Expert weights for the first GEMM; ``w1.shape[0]`` is the expert count.
        w2: Expert weights for the second GEMM.
        topk_weight: Precomputed routing weights.
        topk_ids: Precomputed expert ids per token.
        w1_scale: Optional scale for ``w1``.
        w2_scale: Optional scale for ``w2``.
        a1_scale: Optional scale for the first GEMM's activations.
        a2_scale: Optional scale for the second GEMM's activations.
        quant_dtype: Quantization dtype, or None for unquantized execution.
        per_act_token_quant: Use per-token activation scales.
        block_shape: Block-quantization shape, or None.

    Returns:
        The MoE output for ``a``.
    """
    # Per-expert batch capacity: the token count rounded up to a multiple of 64.
    max_num_tokens = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
    moe_config = make_dummy_moe_config()

    # Named ``kernel`` so it does not shadow the module-level ``fused_experts``.
    kernel = FusedMoEKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=num_experts, rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=moe_config,
        ),
        inplace=False,
    )

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=num_experts,
        activation=moe_config.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
207
208


209
def chunk_scales(
210
211
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
212
213
214
215
216
217
218
219
220
221
222
223
224
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Create random ``(E, m, k)`` CUDA activations and optionally quantize them.

    Each expert slice ``a[e]`` is quantized independently with
    ``moe_kernel_quantize_input``. Only fp8 (e4m3) and int8 are supported.

    Returns:
        ``(a, a_q, a_scale)``: the unquantized activations; the quantized
        activations (``a`` itself when ``quant_dtype`` is None); and the
        stacked per-expert scales (``None`` when ``quant_dtype`` is None).
        For per-tensor quantization the scales are viewed as ``(E, 1, 1)``.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    a_q = a
    a_scale = None

    if quant_dtype is not None:
        assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
            "only fp8/int8 supported"
        )
        a_q = torch.zeros_like(a, dtype=quant_dtype)
        a_scale_l = [None] * E
        for e in range(E):
            a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
                a[e], None, quant_dtype, per_act_token_quant, block_shape
            )
        a_scale = torch.stack(a_scale_l)

        # Per-tensor case: one scalar per expert, broadcastable over (m, k).
        if not per_act_token_quant and block_shape is None:
            a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Quantize one expert's weight tensor for MoE tests.

    Args:
        w: Unquantized weight tensor.
        w_s: Optional precomputed scale. It is forwarded to the non-blocked
            int8/fp8 quantizers and ignored for blocked quantization and nvfp4.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8`` or ``"nvfp4"``.
        per_token_quant: Forwarded as ``use_per_token_if_dynamic`` to the
            int8/fp8 quantizers. Must be False with ``block_shape`` or nvfp4.
        block_shape: ``[block_n, block_k]`` for block-wise quantization, or None.

    Returns:
        ``(w_q, w_s, w_gs)``: the quantized weights, their scales, and the
        nvfp4 global scale (None for the other types).

    Raises:
        RuntimeError: If nvfp4 is combined with ``block_shape``.
    """
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"

    w_gs = None

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            # Global scale that maps the tensor's max magnitude onto
            # FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX.
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
294
295
296
297
298
299
300


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Create ``e`` random ``(rows, cols)`` CUDA expert weights, optionally quantized.

    Each expert is quantized independently with ``moe_quantize_weights``.

    Returns:
        ``(w_16, w, w_s, w_gs)``:
        - the unquantized ``(e, rows, cols)`` weights in ``in_dtype``;
        - the (possibly quantized) weights, which are ``w_16`` itself when
          ``quant_dtype`` is None;
        - the stacked per-expert scales. Per-tensor scales are viewed as
          ``(e, 1, 1)``; block scales are ``(e, n_tiles, k_tiles)``;
        - the stacked nvfp4 global scales, or None.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
    w_gs = None

    if quant_dtype is not None:
        w_l = [None] * e
        w_s_l = [None] * e
        w_gs_l = [None] * e
        for idx in range(e):
            w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
                w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
            )

        w = torch.stack(w_l)
        w_s = torch.stack(w_s_l)
        # Global scales are produced only on the nvfp4 path; otherwise all None.
        if e > 0 and w_gs_l[0] is not None:
            w_gs = torch.stack(w_gs_l)

        # An (e, 1) stack is one scalar scale per expert. Make it broadcastable
        # against the (e, rows, cols) weights.
        if w_s.ndim == 2:
            assert w_s.shape[-1] == 1
            w_s = w_s.view(-1, 1, 1)

        if block_shape is not None:
            block_n, block_k = block_shape
            n_tiles = (rows + block_n - 1) // block_n
            k_tiles = (cols + block_k - 1) // block_k
            assert w_s.shape == (e, n_tiles, k_tiles)
    else:
        w = w_16
        w_s = None

    return w_16, w, w_s, w_gs
bnellnm's avatar
bnellnm committed
336
337
338
339
340
341
342


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """Create the two expert weight sets of an MoE layer.

    The first set has ``(e, 2 * n, k)`` weights when ``make_gate`` is True
    (otherwise ``(e, n, k)``). The second has ``(e, k, n)`` weights.

    Returns:
        ``((w1_16, w1, w1_s, w1_gs), (w2_16, w2, w2_s, w2_gs))``, each as
        returned by ``make_test_weight``.
    """
    return (
        make_test_weight(
            e,
            (2 if make_gate else 1) * n,
            k,
            in_dtype,
            quant_dtype,
            block_shape,
            per_out_ch_quant,
        ),
        make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
    )
363
364
365


def per_token_cast_to_fp8(
366
367
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
368
369
370
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
371
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
372
373
374
375
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
376
377


378
379
380
381
382
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """Create (possibly quantized) expert weights plus a matching quant config.

    ``per_act_token_quant`` also selects per-output-channel weight
    quantization. For nvfp4, trivial all-ones per-expert global scales are
    used, and they double as the activation scales.

    Returns:
        ``(w1, w2, quant_config)``.
    """
    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )

    # Hacky/trivial scales for nvfp4.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    a1_scale: torch.Tensor | None
    a2_scale: torch.Tensor | None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a1_scale = a1_gscale
        a2_scale = a2_gscale
    else:
        a1_scale = None
        a2_scale = None

    return (
        w1,
        w2,
        FusedMoEQuantConfig.make(
            quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_gscale=a1_gscale,
            a2_gscale=a2_gscale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            # TODO: make sure this is handled properly
            g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
            g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
        ),
    )


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """Route tokens with ``fused_topk`` and run ``fused_experts`` on them.

    Args:
        hidden_states: Input activations.
        w1: Expert weights for the first GEMM.
        w2: Expert weights for the second GEMM.
        score: Router scores; cast to float32 before top-k selection.
        topk: Number of experts selected per token.
        renormalize: Forwarded to ``fused_topk``.
        quant_config: Forwarded to ``fused_experts``.
        global_num_experts: Forwarded to ``fused_experts``.
        expert_map: Forwarded to ``fused_experts``.

    Returns:
        The output of ``fused_experts``.
    """
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, score.float(), topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
455
456


457
458
459
460
461
462
463
464
465
466
467
# CustomOp?
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

468
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
469
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511


class TestMLP(torch.nn.Module):
    """Reference gated MLP: ``down(silu_and_mul(gate_up(x)))``.

    Both projections are ``BaselineMM`` float32 matmuls whose outputs are
    cast to ``out_dtype``.
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)[0]
        activated = self.act_fn(gate_up)
        return self.down_proj(activated)[0]


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """Build an unquantized reference shared-expert MLP with random weights.

    The fused gate/up weight is ``(K, 2 * N)`` and the down weight is
    ``(N, K)``. Both are standard-normal samples scaled by 1/15, created on
    CUDA in ``in_dtype``.
    """

    def _rand(rows: int, cols: int) -> torch.Tensor:
        return torch.randn((rows, cols), device="cuda", dtype=in_dtype) / 15

    gate_up_weight = _rand(K, N * 2)
    down_weight = _rand(N, K)
    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """Gated MLP built from real vLLM linear layers with preset weights.

    The gate/up projection is a ``MergedColumnParallelLinear`` and the down
    projection is a ``RowParallelLinear``. After construction, each layer's
    ``weight`` and ``weight_scale`` parameters are replaced with the given
    tensors and its ``input_scale`` is set to ``None``. Only the SiLU
    activation is supported.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()

        # Validate before building the linear layers so bad input fails fast.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self._install_weights(self.gate_up_proj, w1, w1_s)

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self._install_weights(self.down_proj, w2, w2_s)

        self.act_fn = SiluAndMul()

    @staticmethod
    def _install_weights(
        layer: torch.nn.Module,
        weight: torch.Tensor,
        weight_scale: torch.Tensor | None,
    ) -> None:
        """Replace ``layer``'s weight/weight_scale and register a None input_scale."""
        layer.register_parameter(
            "weight", torch.nn.Parameter(weight, requires_grad=False)
        )
        layer.register_parameter(
            "weight_scale", torch.nn.Parameter(weight_scale, requires_grad=False)
        )
        # No static activation scale. Presumably quantized linear methods then
        # quantize activations dynamically — TODO confirm.
        layer.register_parameter("input_scale", None)

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a shared-expert ``RealMLP`` (hidden size ``K``, intermediate ``N``).

    For ``torch.float8_e4m3fn`` the weights are fp8-quantized, transposed, and
    used with an ``Fp8Config``. Every other ``quant_dtype``, including None,
    has no quantized linear path here, so the unquantized ``in_dtype``
    weights are used with no quant config. The global default dtype is
    temporarily set to ``in_dtype`` while the module is built, then restored.
    """
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            # Transpose the single expert's weights/scales, presumably into the
            # layout the fp8 linear method consumes — TODO confirm.
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            quant_config = Fp8Config(True)
        else:
            # Without a quant config the layers need in_dtype weights; int8 or
            # nvfp4 tensors cannot be used as-is, so take the unquantized ones
            # (for quant_dtype=None these are the same tensors).
            w1 = w1_16[0]
            w2 = w2_16[0]
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(old_dtype)
601
602
603
604
605


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEKernel:
    """Assemble a modular ``FusedMoEKernel`` that uses ``TritonExperts``.

    The prepare/finalize stage comes from ``maybe_make_prepare_finalize`` for
    ``moe_config`` (new interface allowed, non-monolithic). The kernel is
    created with ``inplace=False``.
    """
    return FusedMoEKernel(
        maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        ),
        TritonExperts(moe_config, quant_config),
        inplace=False,
    )