utils.py 19.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
12
13
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
14
15
16
17
18
19
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
21
22
23
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
24
25
26
27
from vllm.model_executor.layers.fused_moe.fused_moe import (
    TritonExperts,
    fused_experts,
)
28
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
29
30
31
from vllm.model_executor.layers.fused_moe.prepare_finalize.batched import (
    BatchedPrepareAndFinalize,
)
32
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
33
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
34
from vllm.utils.deep_gemm import per_block_cast_to_fp8
35
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
36
37


38
39
40
41
42
43
44
45
46
47
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of the last dimension of ``w``.

    With the last dim viewed as ``[a0, a1, ..., b0, b1, ...]`` (first half
    ``a``, second half ``b``), the result is ``[a0, b0, a1, b1, ...]``. This
    is the adjacent layout used by the Triton MoE / SwiGLU kernel. The
    output has the same shape as ``w``.
    """
    original_shape = w.shape
    half = original_shape[-1] // 2
    paired = torch.stack((w[..., :half], w[..., half:]), dim=-1)
    return paired.reshape(original_shape)


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """Build a minimal, single-rank ``FusedMoEConfig`` for kernel constructors.

    Most modular-kernel experts (DeepGEMM, CUTLASS FP4, Triton, Marlin) take
    a config in their constructor but do not actually read it. CUTLASS FP8
    does read some of the size fields to set up its workspace shapes.

    The local and logical expert counts both equal ``num_experts`` because
    the parallel config is ``make_no_parallel()``. The device is fixed to
    CUDA, routing to top-k, and ``max_num_tokens`` to 512.
    """
    no_parallel = FusedMoEParallelConfig.make_no_parallel()
    return FusedMoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        num_local_experts=num_experts,
        num_logical_experts=num_experts,
        moe_parallel_config=no_parallel,
        activation=MoEActivation.SILU,
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
        max_num_tokens=512,
    )


bnellnm's avatar
bnellnm committed
78
79
80
81
82
83
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run already-routed tokens through the non-modular ``fused_experts`` path.

    Args:
        a: Input activations.
        w1, w2: First- and second-layer expert weights.
        topk_weight, topk_ids: Routing weights and expert ids per token.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional quantization scales
            forwarded to ``FusedMoEQuantConfig.make``.
        quant_dtype: Quantization dtype, or None for unquantized.
        per_act_token_quant: Whether activations use per-token scales.
        block_shape: Block-quantization shape, or None.

    Returns:
        The output tensor produced by ``fused_experts``.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
103
104
105
106
107
108
109
110


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run already-routed tokens through a single-rank batched Triton MoE kernel.

    The kernel pairs ``BatchedPrepareAndFinalize`` with
    ``BatchedTritonExperts``. Both receive ``max_num_tokens`` equal to the
    token count rounded up to a multiple of 64. The number of local/global
    experts is taken from ``w1.shape[0]``.

    A default ``make_dummy_moe_config()`` (``num_experts=1``) is passed to
    the experts object. NOTE(review): this assumes BatchedTritonExperts does
    not depend on those config sizes; confirm if results look off.

    Returns:
        The output of ``FusedMoEKernel.apply`` (SiLU activation, no expert
        map, router weights not applied on the input).
    """
    max_num_tokens = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    moe_config = make_dummy_moe_config()

    # Named ``kernel`` rather than ``fused_experts`` so it does not shadow the
    # module-level ``fused_experts`` function.
    kernel = FusedMoEKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=num_experts, rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=moe_config,
        ),
        inplace=False,
    )

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=num_experts,
        activation=moe_config.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
157
158
159
160
161
162
163
164


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run already-routed tokens through a single-rank batched MoE using the
    naive (reference) batched experts.

    This matches ``batched_moe`` except that the experts stage is
    ``NaiveBatchedExperts`` instead of ``BatchedTritonExperts``.
    ``max_num_tokens`` is the token count rounded up to a multiple of 64.
    The expert count is taken from ``w1.shape[0]``.

    Returns:
        The output of ``FusedMoEKernel.apply`` (SiLU activation, no expert
        map, router weights not applied on the input).
    """
    max_num_tokens = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
    moe_config = make_dummy_moe_config()

    # Named ``kernel`` rather than ``fused_experts`` so it does not shadow the
    # module-level ``fused_experts`` function.
    kernel = FusedMoEKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=num_experts, rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=moe_config,
        ),
        inplace=False,
    )

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=num_experts,
        activation=moe_config.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
210
211


212
def chunk_scales(
213
214
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
215
216
217
218
219
220
221
222
223
224
225
226
227
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Create random ``(E, m, k)`` CUDA activations and optionally quantize them.

    Returns:
        ``(a, a_q, a_scale)``:
        - ``a`` is the unquantized tensor (standard normal samples divided
          by 10).
        - ``a_q`` is ``a`` itself when ``quant_dtype`` is None. Otherwise it
          is a ``quant_dtype`` tensor filled expert by expert via
          ``moe_kernel_quantize_input``.
        - ``a_scale`` is None when unquantized. Otherwise it is the stack of
          per-expert scales, reshaped to ``(E, 1, 1)`` when neither
          per-token nor block quantization is requested.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10

    if quant_dtype is None:
        return a, a, None

    assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
        "only fp8/int8 supported"
    )

    # Preallocate so a_q keeps a's (E, m, k) shape but with quant_dtype.
    a_q = torch.zeros_like(a, dtype=quant_dtype)
    expert_scales = []
    for expert in range(E):
        a_q[expert], expert_scale = moe_kernel_quantize_input(
            a[expert], None, quant_dtype, per_act_token_quant, block_shape
        )
        expert_scales.append(expert_scale)
    a_scale = torch.stack(expert_scales)

    per_tensor = not per_act_token_quant and block_shape is None
    if per_tensor:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


254
def moe_quantize_weights_2d(
bnellnm's avatar
bnellnm committed
255
    w: torch.Tensor,
256
257
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
bnellnm's avatar
bnellnm committed
258
    per_token_quant: bool,
259
260
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
261
262
263
264
265
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"
266
267

    w_gs = None
bnellnm's avatar
bnellnm committed
268
269
270
271
272

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
273
        elif quant_dtype == torch.float8_e4m3fn:
bnellnm's avatar
bnellnm committed
274
            w, w_s = per_block_cast_to_fp8(w, block_shape)
275
276
277
278
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
bnellnm's avatar
bnellnm committed
279
280
281
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
282
283
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
284
        elif quant_dtype == torch.float8_e4m3fn:
bnellnm's avatar
bnellnm committed
285
            w, w_s = ops.scaled_fp8_quant(
286
287
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
288
289
290
291
292
293
294
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
bnellnm's avatar
bnellnm committed
295

296
    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
297
298


299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Quantize a stack of per-expert weight matrices ``w`` of shape ``(e, rows, cols)``.

    Each expert slice ``w[i]`` is quantized on its own with
    ``moe_quantize_weights_2d``. The quantized weights, scales and (if
    produced) global scales are then re-stacked along a new leading expert
    dimension.

    Note: the ``w_s`` argument is not forwarded. ``None`` is passed as the
    scale for every expert.

    A 2-D stacked scale must have a trailing dim of 1 and is reshaped to
    ``(e, 1, 1)``. With ``block_shape`` the scale is asserted to hold one
    entry per ``block_n x block_k`` tile, i.e. shape
    ``(e, ceil(rows / block_n), ceil(cols / block_k))``.

    Returns:
        ``(w_q, w_scale, w_global_scale)``. ``w_global_scale`` is None when
        the per-expert global scales are None.
    """
    assert w.dim() == 3
    num_experts, rows, cols = w.shape

    per_expert = [
        moe_quantize_weights_2d(
            w[idx], None, quant_dtype, per_token_quant, block_shape
        )
        for idx in range(num_experts)
    ]

    w_q = torch.stack([q for q, _, _ in per_expert])
    w_scale = torch.stack([s for _, s, _ in per_expert])
    first_global = per_expert[0][2] if per_expert else None
    w_global = (
        torch.stack([gs for _, _, gs in per_expert])
        if first_global is not None
        else None
    )

    if w_scale.ndim == 2:
        assert w_scale.shape[-1] == 1
        w_scale = w_scale.view(-1, 1, 1)

    if block_shape is not None:
        block_n, block_k = block_shape
        expected = (
            num_experts,
            (rows + block_n - 1) // block_n,
            (cols + block_k - 1) // block_k,
        )
        assert w_scale.shape == expected

    return w_q, w_scale, w_global


bnellnm's avatar
bnellnm committed
333
334
335
336
337
def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Create a random ``(e, rows, cols)`` CUDA weight, optionally quantized.

    Returns:
        ``(w_16, w, w_s, w_gs)``:
        - ``w_16`` is the unquantized reference (standard normal divided by
          15, in ``in_dtype``).
        - ``w`` is ``w_16`` itself when ``quant_dtype`` is None, otherwise
          the result of ``moe_quantize_weights``.
        - ``w_s`` and ``w_gs`` are the scales and global scales from that
          call, or both None when unquantized.
    """
    reference = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return reference, reference, None, None

    quantized, scales, global_scales = moe_quantize_weights(
        reference, None, quant_dtype, per_out_ch_quant, block_shape
    )
    return reference, quantized, scales, global_scales
bnellnm's avatar
bnellnm committed
354
355
356
357
358
359
360


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """Create the two expert weight sets for an MoE layer.

    The first set has shape ``(e, 2 * n, k)`` when ``make_gate`` is True
    (presumably the concatenated gate and up projections), otherwise
    ``(e, n, k)``. The second set has shape ``(e, k, n)``. Each element is
    the ``(w_16, w, w_s, w_gs)`` tuple returned by ``make_test_weight``,
    and both sets share the same dtype and quantization settings.
    """
    first_rows = 2 * n if make_gate else n
    first = make_test_weight(
        e,
        first_rows,
        k,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_out_ch_quant,
    )
    second = make_test_weight(
        e,
        k,
        n,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_out_ch_quant,
    )
    return first, second
381
382
383


def per_token_cast_to_fp8(
384
385
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
386
387
388
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
389
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
390
391
392
393
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
394
395


396
397
398
399
400
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """Create test expert weights together with a matching quant config.

    The weights come from ``make_test_weights``, with ``per_act_token_quant``
    also used as the per-output-channel weight setting.

    For nvfp4, the activation global scales are hacky, trivial all-ones
    float32 CUDA tensors of length ``e``, and the same tensors are reused as
    the activation scales. For other dtypes all activation scales are None.
    The g1/g2 alphas are the reciprocals of the weight global scales when
    those exist.

    Returns:
        ``(w1, w2, quant_config)`` with the (possibly quantized) weights.
    """
    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )

    a1_gscale: torch.Tensor | None
    a2_gscale: torch.Tensor | None
    if quant_dtype == "nvfp4":
        # Hacky/trivial scales for nvfp4.
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
    else:
        a1_gscale = None
        a2_gscale = None

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        a1_scale=a1_gscale,
        a2_scale=a2_gscale,
        # TODO: make sure this is handled properly
        g1_alphas=None if w1_gs is None else 1 / w1_gs,
        g2_alphas=None if w2_gs is None else 1 / w2_gs,
    )
    return w1, w2, quant_config


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """Route tokens with top-k over ``score``, then run ``fused_experts``.

    ``score`` is upcast to float before routing. The routing's third
    output is discarded. ``global_num_experts``, ``expert_map`` and
    ``quant_config`` are passed through to ``fused_experts`` unchanged.
    """
    float_scores = score.float()
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, float_scores, topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        quant_config=quant_config,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )
473
474


475
476
477
478
479
480
481
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
482
        self.b = torch.nn.Parameter(b.to(dtype=torch.float32))
483
484
        self.out_dtype = out_dtype

485
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
486
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
487
488


489
490
491
492
493
494
495
496
497
class BaselineSiluAndMul(torch.nn.Module):
    """Reference SiLU-and-multiply activation.

    The last dimension is split into two halves. The result is
    ``silu(first_half) * second_half``, so the last dim of the output is
    half that of the input.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return torch.nn.functional.silu(gate) * up


498
499
500
501
502
503
504
505
506
507
class TestMLP(torch.nn.Module):
    """Float32 reference MLP built from ``BaselineMM`` and ``BaselineSiluAndMul``.

    The forward pass computes ``down(silu_and_mul(x @ w1)) @ w2``. The
    output of ``x @ w1`` is split in half along its last dim by the
    activation. Both projections cast their result to ``out_dtype``.
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = BaselineSiluAndMul()

    def forward(self, x):
        hidden, _ = self.gate_up_proj(x)
        activated = self.act_fn(hidden)
        result, _ = self.down_proj(activated)
        return result


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """Build a reference ``TestMLP`` with random CUDA weights.

    The first weight is ``(K, 2 * N)`` and the second is ``(N, K)``. Both
    are standard normal divided by 15 and generated in that order.
    """
    gate_up_weight = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
    down_weight = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """MLP made of real vLLM linear layers whose weights are supplied directly.

    A ``MergedColumnParallelLinear`` (gate/up, ``[intermediate_size] * 2``
    outputs) is followed by ``SiluAndMul`` and a ``RowParallelLinear``
    (down). After construction, each layer's ``weight`` and ``weight_scale``
    parameters are replaced with the given tensors (frozen), and
    ``input_scale`` is registered as None. Only ``hidden_act="silu"`` is
    accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self._install_weights(self.gate_up_proj, w1, w1_s)

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self._install_weights(self.down_proj, w2, w2_s)

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    @staticmethod
    def _install_weights(
        layer: torch.nn.Module,
        weight: torch.Tensor,
        weight_scale: torch.Tensor | None,
    ) -> None:
        """Replace ``layer``'s weight/scale with frozen copies and clear ``input_scale``."""
        layer.register_parameter(
            "weight", torch.nn.Parameter(weight, requires_grad=False)
        )
        # NOTE(review): when weight_scale is None this still registers a
        # Parameter (wrapping an empty tensor), matching the original code.
        layer.register_parameter(
            "weight_scale", torch.nn.Parameter(weight_scale, requires_grad=False)
        )
        layer.register_parameter("input_scale", None)

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated)
        return out


593
def make_shared_experts_with_weights(
    N: int,
    K: int,
    in_dtype: torch.dtype,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_s: torch.Tensor | None = None,
    w2_s: torch.Tensor | None = None,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Wrap the given weights in a ``RealMLP`` (hidden size ``K``, intermediate ``N``).

    The torch default dtype is set to ``in_dtype`` while the module is
    built and restored afterwards, even if construction raises. An
    ``Fp8Config(True)`` quant config is used only when ``quant_dtype`` is
    ``torch.float8_e4m3fn``; otherwise the layers are unquantized.
    """
    previous_default = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        quant_config = None
        if quant_dtype == torch.float8_e4m3fn:
            # Imported lazily, only when fp8 is actually requested.
            from vllm.model_executor.layers.quantization.fp8 import Fp8Config

            quant_config = Fp8Config(True)
        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(previous_default)
616
617
618
619
620


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEKernel:
    """Assemble a non-inplace ``FusedMoEKernel`` with ``TritonExperts``.

    The prepare/finalize stage comes from ``maybe_make_prepare_finalize``
    with the new interface allowed and the monolithic path disabled.
    """
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=quant_config,
        allow_new_interface=True,
        use_monolithic=False,
    )
    experts = TritonExperts(moe_config, quant_config)
    return FusedMoEKernel(prepare_finalize, experts, inplace=False)
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a ``RealMLP`` shared-experts module from freshly generated weights.

    Single-expert (``e=1``) test weights are created with
    ``make_test_weights`` and passed, together with their scales, to
    ``make_shared_experts_with_weights``. The global scales are discarded.
    """
    gate_up_set, down_set = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    _, w1, w1_s, _ = gate_up_set
    _, w2, w2_s, _ = down_set

    return make_shared_experts_with_weights(
        N, K, in_dtype, w1, w2, w1_s=w1_s, w2_s=w2_s, quant_dtype=quant_dtype
    )