"vllm/vscode:/vscode.git/clone" did not exist on "3272d7a0b715a79059ebfcf8959d1ac0488ad18c"
utils.py 19.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
12
13
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
14
15
16
17
18
19
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
21
22
23
24
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
25
26
27
28
from vllm.model_executor.layers.fused_moe.fused_moe import (
    TritonExperts,
    fused_experts,
)
29
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
30
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
31
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
32
from vllm.utils.deep_gemm import per_block_cast_to_fp8
33
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
34
35


36
37
38
39
40
41
42
43
44
45
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of the last dimension of ``w``.

    The last dimension is split into a first half and a second half, and the
    result alternates their elements (first[0], second[0], first[1], ...),
    which is the adjacent-pair layout used by the Triton MoE / SwiGLU kernel.
    The output has the same shape as ``w``.
    """
    half = w.shape[-1] // 2
    # Pair element i of the first half with element i of the second half.
    interleaved = torch.stack((w[..., :half], w[..., half:]), dim=-1)
    return interleaved.reshape(w.shape)


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """Build a placeholder ``FusedMoEConfig`` for modular-kernel constructors.

    The modular kernel constructors require a config argument. Most kernels
    (e.g. DeepGEMM, CUTLASSFp4, Triton, MARLIN) do not actually read it.
    CUTLASSFp8 does need some of these params to size its workspaces.

    The config uses no parallelism, SiLU activation, top-k routing and the
    "cuda" device. Local and logical expert counts both equal ``num_experts``.
    """
    parallel_config = FusedMoEParallelConfig.make_no_parallel()
    return FusedMoEConfig(
        num_experts=num_experts,
        num_local_experts=num_experts,
        num_logical_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        moe_parallel_config=parallel_config,
        activation=MoEActivation.SILU,
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )


bnellnm's avatar
bnellnm committed
75
76
77
78
79
80
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run the (non-modular) Triton ``fused_experts`` path with precomputed routing.

    Args:
        a: Input activations.
        w1, w2: Expert weights.
        topk_weight, topk_ids: Routing weights / expert ids per token.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional quantization scales,
            forwarded into ``FusedMoEQuantConfig.make``.
        quant_dtype: Quantization dtype, or ``None`` for unquantized.
        per_act_token_quant: Use per-token activation scales.
        block_shape: Block-quantization shape, or ``None``.

    Returns:
        The output of ``fused_experts``.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
100
101
102
103
104
105
106
107


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run MoE through the modular batched Triton path.

    Builds a ``FusedMoEKernel`` from ``BatchedPrepareAndFinalize`` and
    ``BatchedTritonExperts`` (single dispatcher, rank 0, not in-place) and
    applies it with SiLU activation and no expert map.

    Args:
        a: Input activations; ``a.shape[0]`` is the token count.
        w1, w2: Expert weights; ``w1.shape[0]`` is used as both the local and
            the global number of experts.
        topk_weight, topk_ids: Routing weights / expert ids per token.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional quantization scales.
        quant_dtype: Quantization dtype, or ``None`` for unquantized.
        per_act_token_quant: Use per-token activation scales.
        block_shape: Block-quantization shape, or ``None``.

    Returns:
        The output of the kernel's ``apply``.
    """
    # Per-expert token capacity: the token count rounded up to a multiple of 64.
    max_num_tokens = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    moe_config = make_dummy_moe_config()

    # Named `kernel` rather than `fused_experts` so it does not shadow the
    # module-level `fused_experts` function used by other helpers here.
    kernel = FusedMoEKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=num_experts, rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=moe_config,
        ),
        inplace=False,
    )

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=num_experts,
        activation=moe_config.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
154
155
156
157
158
159
160
161


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run MoE through the modular batched path with ``NaiveBatchedExperts``.

    Same setup as the batched Triton path, but the experts stage is
    ``NaiveBatchedExperts``. The kernel uses a single dispatcher, rank 0, is
    not in-place, and is applied with SiLU activation and no expert map.

    Args:
        a: Input activations; ``a.shape[0]`` is the token count.
        w1, w2: Expert weights; ``w1.shape[0]`` is used as both the local and
            the global number of experts.
        topk_weight, topk_ids: Routing weights / expert ids per token.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional quantization scales.
        quant_dtype: Quantization dtype, or ``None`` for unquantized.
        per_act_token_quant: Use per-token activation scales.
        block_shape: Block-quantization shape, or ``None``.

    Returns:
        The output of the kernel's ``apply``.
    """
    # Per-expert token capacity: the token count rounded up to a multiple of 64.
    max_num_tokens = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
    moe_config = make_dummy_moe_config()

    # Named `kernel` rather than `fused_experts` so it does not shadow the
    # module-level `fused_experts` function used by other helpers here.
    kernel = FusedMoEKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=num_experts, rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
            moe_config=moe_config,
        ),
        inplace=False,
    )

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=num_experts,
        activation=moe_config.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
207
208


209
def chunk_scales(
210
211
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
212
213
214
215
216
217
218
219
220
221
222
223
224
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Create random per-expert activations and, optionally, a quantized copy.

    Returns ``(a, a_q, a_scale)``:
      * ``a``: random ``(E, m, k)`` CUDA tensor of ``in_dtype``, scaled by 1/10.
      * ``a_q``: ``a`` itself when ``quant_dtype`` is ``None``; otherwise a
        ``quant_dtype`` tensor quantized expert-by-expert.
      * ``a_scale``: ``None`` when unquantized; otherwise the stacked per-expert
        scales, reshaped to ``(E, 1, 1)`` for per-tensor quantization (no
        per-token scales and no block shape).
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10

    if quant_dtype is None:
        return a, a, None

    assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
        "only fp8/int8 supported"
    )

    a_q = torch.zeros_like(a, dtype=quant_dtype)
    expert_scales = []
    for expert in range(E):
        a_q[expert], expert_scale = moe_kernel_quantize_input(
            a[expert], None, quant_dtype, per_act_token_quant, block_shape
        )
        expert_scales.append(expert_scale)

    a_scale = torch.stack(expert_scales)
    is_per_tensor = not per_act_token_quant and block_shape is None
    if is_per_tensor:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


251
def moe_quantize_weights_2d(
bnellnm's avatar
bnellnm committed
252
    w: torch.Tensor,
253
254
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
bnellnm's avatar
bnellnm committed
255
    per_token_quant: bool,
256
257
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
258
259
260
261
262
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"
263
264

    w_gs = None
bnellnm's avatar
bnellnm committed
265
266
267
268
269

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
270
        elif quant_dtype == torch.float8_e4m3fn:
bnellnm's avatar
bnellnm committed
271
            w, w_s = per_block_cast_to_fp8(w, block_shape)
272
273
274
275
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
bnellnm's avatar
bnellnm committed
276
277
278
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
279
280
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
281
        elif quant_dtype == torch.float8_e4m3fn:
bnellnm's avatar
bnellnm committed
282
            w, w_s = ops.scaled_fp8_quant(
283
284
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
285
286
287
288
289
290
291
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
bnellnm's avatar
bnellnm committed
292

293
    return w, w_s, w_gs
bnellnm's avatar
bnellnm committed
294
295


296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Quantize a stack of expert weights ``(e, rows, cols)`` expert by expert.

    Each expert slice is quantized with ``moe_quantize_weights_2d`` (the
    incoming ``w_s`` is not forwarded; ``None`` is passed per expert). The
    per-expert results are then stacked along a new leading dimension.

    Returns:
        ``(w_q, w_s, w_gs)``:
          * ``w_s`` is reshaped to ``(e, 1, 1)`` when each expert produced a
            single-column scale.
          * ``w_gs`` is ``None`` unless the per-expert global scales exist
            (nvfp4).
        With ``block_shape`` the scale shape is asserted to be
        ``(e, ceil(rows / block_n), ceil(cols / block_k))``.
    """
    assert w.dim() == 3
    num_experts, rows, cols = w.shape

    per_expert = [
        moe_quantize_weights_2d(
            w[expert], None, quant_dtype, per_token_quant, block_shape
        )
        for expert in range(num_experts)
    ]

    w = torch.stack([q for q, _, _ in per_expert])
    w_s = torch.stack([scale for _, scale, _ in per_expert])
    has_global_scale = num_experts > 0 and per_expert[0][2] is not None
    w_gs = torch.stack([gs for _, _, gs in per_expert]) if has_global_scale else None

    if w_s.ndim == 2:
        # One scale per expert (shape (e, 1)); make it broadcastable as (e, 1, 1).
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        block_n, block_k = block_shape
        expected_shape = (
            num_experts,
            (rows + block_n - 1) // block_n,
            (cols + block_k - 1) // block_k,
        )
        assert w_s.shape == expected_shape

    return w, w_s, w_gs


bnellnm's avatar
bnellnm committed
330
331
332
333
334
def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Create a random ``(e, rows, cols)`` CUDA weight and optionally quantize it.

    Returns ``(w_16, w, w_s, w_gs)``. ``w_16`` is the unquantized weight (scaled
    by 1/15). When ``quant_dtype`` is ``None``, ``w`` is ``w_16`` itself and both
    scales are ``None``. Otherwise ``w``, ``w_s`` and ``w_gs`` come from
    ``moe_quantize_weights``.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None, None

    w, w_s, w_gs = moe_quantize_weights(
        w_16, None, quant_dtype, per_out_ch_quant, block_shape
    )
    return w_16, w, w_s, w_gs
bnellnm's avatar
bnellnm committed
351
352
353
354
355
356
357


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """Create the (w1, w2) weight bundles for an MoE layer.

    ``w1`` has shape ``(e, 2n, k)`` when ``make_gate`` is true (gate and up
    projections fused), else ``(e, n, k)``. ``w2`` has shape ``(e, k, n)``.
    Each bundle is the ``(w_16, w, w_s, w_gs)`` tuple from ``make_test_weight``.
    """
    w1_rows = 2 * n if make_gate else n
    w1_bundle = make_test_weight(
        e, w1_rows, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    w2_bundle = make_test_weight(
        e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    return w1_bundle, w2_bundle
378
379
380


def per_token_cast_to_fp8(
381
382
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
383
384
385
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
386
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
387
388
389
390
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
391
392


393
394
395
396
397
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """Create (possibly quantized) MoE test weights plus a matching quant config.

    Returns ``(w1, w2, quant_config)``. ``w1``/``w2`` are the (quantized) weights
    from ``make_test_weights``. ``per_act_token_quant`` is also used as the
    per-output-channel weight quantization flag.

    For ``"nvfp4"`` the activation global scales are trivial all-ones
    per-expert tensors, and the same tensors are reused as the activation
    scales. The g1/g2 alphas are the reciprocals of the weight global scales.
    """
    w1_bundle, w2_bundle = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )
    _, w1, w1_s, w1_gs = w1_bundle
    _, w2, w2_s, w2_gs = w2_bundle

    # Hacky/trivial scales for nvfp4; None for every other quant type.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
    a1_scale = a1_gscale
    a2_scale = a2_gscale

    # TODO: make sure this is handled properly
    g1_alphas = None if w1_gs is None else 1 / w1_gs
    g2_alphas = None if w2_gs is None else 1 / w2_gs

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
    )
    return w1, w2, quant_config


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """Route tokens with ``fused_topk`` and run them through ``fused_experts``.

    ``score`` is converted to float32 before top-k selection. The third value
    returned by ``fused_topk`` is discarded.
    """
    routing_logits = score.float()
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, routing_logits, topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        quant_config=quant_config,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )
470
471


472
473
474
475
476
477
478
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
479
        self.b = torch.nn.Parameter(b.to(dtype=torch.float32))
480
481
        self.out_dtype = out_dtype

482
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
483
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
484
485


486
487
488
489
490
491
492
493
494
class BaselineSiluAndMul(torch.nn.Module):
    """Reference SwiGLU-style activation: ``silu(gate) * up``.

    The last dimension is split in half. The first half is the gate and the
    second half the up projection.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        gate = x[..., :half]
        up = x[..., half:]
        return torch.nn.functional.silu(gate) * up


495
496
497
498
499
500
501
502
503
504
class TestMLP(torch.nn.Module):
    """Reference float32 MLP: gate_up matmul -> SiLU-and-mul -> down matmul.

    Despite the ``Test`` prefix this is a helper module, not a pytest test
    class. ``__test__ = False`` stops pytest from trying (and warning about
    being unable) to collect it when it is imported into a test module.
    """

    __test__ = False

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = BaselineSiluAndMul()

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """Build a reference ``TestMLP`` shared expert with random CUDA weights.

    The gate_up weight is ``(K, 2N)`` and the down weight is ``(N, K)``, both
    scaled by 1/15. Outputs are produced in ``in_dtype``.
    """
    gate_up_weight = torch.randn((K, 2 * N), device="cuda", dtype=in_dtype) / 15
    down_weight = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
535
536
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
537
538
    ) -> None:
        from vllm.model_executor.layers.linear import (
539
540
541
            MergedColumnParallelLinear,
            RowParallelLinear,
        )
542
543
544

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
545
546
            hidden_size,
            [intermediate_size] * 2,
547
548
            bias=False,
            quant_config=quant_config,
549
550
            prefix=f"{prefix}.gate_up_proj",
        )
551
        self.gate_up_proj.register_parameter(
552
553
            "weight", torch.nn.Parameter(w1, requires_grad=False)
        )
554
        self.gate_up_proj.register_parameter(
555
556
            "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
        )
557
        self.gate_up_proj.register_parameter(
558
559
560
561
562
563
564
565
566
567
            "input_scale", None
        )  # torch.nn.Parameter(None, requires_grad=False))
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
568
        self.down_proj.register_parameter(
569
570
            "weight", torch.nn.Parameter(w2, requires_grad=False)
        )
571
        self.down_proj.register_parameter(
572
573
            "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
        )
574
        self.down_proj.register_parameter(
575
576
            "input_scale", None
        )  # torch.nn.Parameter(None, requires_grad=False))
577
        if hidden_act != "silu":
578
579
580
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
581
582
583
584
585
586
587
588
589
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


590
def make_shared_experts_with_weights(
    N: int,
    K: int,
    in_dtype: torch.dtype,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_s: torch.Tensor | None = None,
    w2_s: torch.Tensor | None = None,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a ``RealMLP`` shared expert (hidden size ``K``, intermediate ``N``).

    The torch default dtype is temporarily set to ``in_dtype`` while the module
    is built, and it is restored afterwards even on error. An ``Fp8Config`` is
    used when ``quant_dtype`` is ``torch.float8_e4m3fn``; otherwise the layers
    are unquantized.
    """
    previous_default = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        quant_config = None
        if quant_dtype == torch.float8_e4m3fn:
            from vllm.model_executor.layers.quantization.fp8 import Fp8Config

            quant_config = Fp8Config(True)
        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(previous_default)
613
614
615
616
617


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEKernel:
    """Assemble a non-in-place modular ``FusedMoEKernel`` with Triton experts.

    The prepare/finalize stage comes from ``maybe_make_prepare_finalize``
    (new interface allowed, non-monolithic). The experts stage is
    ``TritonExperts``.
    """
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=quant_config,
        allow_new_interface=True,
        use_monolithic=False,
    )
    experts = TritonExperts(moe_config, quant_config)
    return FusedMoEKernel(prepare_finalize, experts, inplace=False)
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a ``RealMLP`` shared expert from freshly generated test weights.

    The weights come from ``make_test_weights`` for a single expert. Only the
    (possibly quantized) weights and their scales are kept; global scales and
    the unquantized copies are discarded.
    """
    w1_bundle, w2_bundle = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    _, w1, w1_s, _ = w1_bundle
    _, w2, w2_s, _ = w2_bundle

    return make_shared_experts_with_weights(
        N, K, in_dtype, w1, w2, w1_s=w1_s, w2_s=w2_s, quant_dtype=quant_dtype
    )