utils.py 19.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
bnellnm's avatar
bnellnm committed
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import torch

bnellnm's avatar
bnellnm committed
6
import vllm._custom_ops as ops
7
from tests.kernels.quant_utils import per_block_cast_to_int8
8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
9
from vllm.model_executor.layers.activation import SiluAndMul
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
12
13
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
14
15
16
17
18
19
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
21
22
23
24
    BatchedPrepareAndFinalize,
    BatchedTritonExperts,
    NaiveBatchedExperts,
)
25
26
27
28
from vllm.model_executor.layers.fused_moe.fused_moe import (
    TritonExperts,
    fused_experts,
)
29
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
30
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
31
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
32
from vllm.utils.deep_gemm import per_block_cast_to_fp8
33
from vllm.utils.math_utils import round_up
bnellnm's avatar
bnellnm committed
34
35


36
37
38
39
40
41
42
43
44
45
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of ``w``'s last dimension.

    Element ``i`` of the first half lands at position ``2 * i`` and element
    ``i`` of the second half at ``2 * i + 1``. This is the adjacent layout used
    by the Triton MoE / SwiGLU kernel. The result has the same shape as ``w``.
    The last dimension must be even, otherwise ``torch.stack`` rejects the
    unequal halves.
    """
    half = w.shape[-1] // 2
    pairs = torch.stack((w[..., :half], w[..., half:]), dim=-1)
    return pairs.reshape(w.shape)


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """Return a placeholder ``FusedMoEConfig`` for modular-kernel constructors.

    Most expert implementations (DeepGEMM, CUTLASS FP4, Triton, Marlin) take a
    config through the constructor interface but do not actually use it.
    CUTLASS FP8 does read some of these params (for workspace shapes).

    The config has no parallelism, treats all ``num_experts`` as both local and
    logical experts, and uses SiLU activation with top-k routing on ``"cuda"``.
    It also sets ``max_num_tokens=512``.
    """
    no_parallel = FusedMoEParallelConfig.make_no_parallel()
    return FusedMoEConfig(
        num_experts=num_experts,
        num_local_experts=num_experts,
        num_logical_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        in_dtype=in_dtype,
        device="cuda",
        moe_parallel_config=no_parallel,
        activation=MoEActivation.SILU,
        routing_method=RoutingMethodType.TopK,
        max_num_tokens=512,
    )


bnellnm's avatar
bnellnm committed
76
77
78
79
80
81
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run the non-batched Triton ``fused_experts`` path on pre-routed tokens.

    Args:
        a: Input activations.
        w1: First expert weight stack.
        w2: Second expert weight stack.
        topk_weight: Routing weights, e.g. as produced by ``fused_topk``.
        topk_ids: Selected expert ids, e.g. as produced by ``fused_topk``.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional scales, forwarded to
            ``FusedMoEQuantConfig.make``.
        quant_dtype: Quantization dtype, or ``None`` for no quantization.
        per_act_token_quant: Forwarded to ``FusedMoEQuantConfig.make``.
        block_shape: Block-quantization shape forwarded to the quant config.

    Returns:
        The output of ``fused_experts``.
    """
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
bnellnm's avatar
bnellnm committed
101
102
103
104
105
106
107
108


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run ``a`` through ``BatchedTritonExperts`` inside a ``FusedMoEKernel``.

    The kernel pairs a single-dispatcher, rank-0 ``BatchedPrepareAndFinalize``
    with ``BatchedTritonExperts``. Both get ``max_num_tokens`` equal to the
    row count of ``a`` rounded up to a multiple of 64. The kernel is applied
    out-of-place with the activation of ``make_dummy_moe_config()``, no
    expert map, and ``apply_router_weight_on_input=False``.
    """
    token_cap = round_up(a.shape[0], 64)
    n_experts = w1.shape[0]

    qcfg = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
    mcfg = make_dummy_moe_config()

    prep_fin = BatchedPrepareAndFinalize(
        token_cap, num_dispatchers=1, num_local_experts=n_experts, rank=0
    )
    experts = BatchedTritonExperts(
        max_num_tokens=token_cap,
        num_dispatchers=1,
        quant_config=qcfg,
        moe_config=mcfg,
    )
    # Named `kernel` so it does not shadow the module-level `fused_experts`.
    kernel = FusedMoEKernel(prep_fin, experts, inplace=False)

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=n_experts,
        activation=mcfg.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
155
156
157
158
159
160
161
162


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run ``a`` through ``NaiveBatchedExperts`` inside a ``FusedMoEKernel``.

    The setup matches ``batched_moe`` except for the experts implementation.
    The kernel uses a single-dispatcher, rank-0 ``BatchedPrepareAndFinalize``,
    with ``max_num_tokens`` equal to the row count of ``a`` rounded up to a
    multiple of 64. It is applied out-of-place with the activation of
    ``make_dummy_moe_config()``, no expert map, and
    ``apply_router_weight_on_input=False``.
    """
    token_cap = round_up(a.shape[0], 64)
    n_experts = w1.shape[0]

    qcfg = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
    mcfg = make_dummy_moe_config()

    prep_fin = BatchedPrepareAndFinalize(
        token_cap, num_dispatchers=1, num_local_experts=n_experts, rank=0
    )
    experts = NaiveBatchedExperts(
        max_num_tokens=token_cap,
        num_dispatchers=1,
        quant_config=qcfg,
        moe_config=mcfg,
    )
    # Named `kernel` so it does not shadow the module-level `fused_experts`.
    kernel = FusedMoEKernel(prep_fin, experts, inplace=False)

    return kernel.apply(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=n_experts,
        activation=mcfg.activation,
        apply_router_weight_on_input=False,
        expert_map=None,
    )
bnellnm's avatar
bnellnm committed
208
209


210
def chunk_scales(
211
212
    scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
bnellnm's avatar
bnellnm committed
213
214
215
216
217
218
219
220
221
222
223
224
225
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Create random ``(E, m, k)`` CUDA activations and optionally quantize them.

    Returns:
        ``(a, a_q, a_scale)``:
        - ``a`` is ``randn / 10`` in ``in_dtype``.
        - ``a_q`` is ``a`` itself when ``quant_dtype`` is ``None``. Otherwise
          each expert slice ``a[e]`` is quantized with
          ``moe_kernel_quantize_input``.
        - ``a_scale`` is ``None`` when unquantized. Otherwise it stacks the
          per-expert scales, reshaped to ``(E, 1, 1)`` when neither per-token
          nor block quantization is requested.

    Raises:
        AssertionError: if ``quant_dtype`` is neither fp8 (e4m3fn) nor int8.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10

    if quant_dtype is None:
        return a, a, None

    assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
        "only fp8/int8 supported"
    )

    a_q = torch.zeros_like(a, dtype=quant_dtype)
    scale_list: list[torch.Tensor | None] = [None] * E
    for expert in range(E):
        a_q[expert], scale_list[expert] = moe_kernel_quantize_input(
            a[expert], None, quant_dtype, per_act_token_quant, block_shape
        )

    a_scale = torch.stack(scale_list)
    if block_shape is None and not per_act_token_quant:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


252
def moe_quantize_weights_2d(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Quantize one 2-D weight matrix to fp8, int8 or nvfp4.

    Args:
        w: Weight to quantize.
        w_s: Optional scale handed to the int8/fp8 ops in the non-block
            path. The block and nvfp4 paths ignore it.
        quant_dtype: ``torch.float8_e4m3fn``, ``torch.int8`` or ``"nvfp4"``.
        per_token_quant: Passed as ``use_per_token_if_dynamic``. Must be
            False for block quantization and for nvfp4.
        block_shape: If set, int8/fp8 use per-block casting.

    Returns:
        ``(w_q, w_s, w_gs)``. ``w_gs`` is only non-None for nvfp4, where it is
        the global scale ``FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax(|w|)``.

    Raises:
        RuntimeError: if block quantization is requested for nvfp4.
    """
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            q, s = per_block_cast_to_int8(w, block_shape)
            return q, s, None
        if quant_dtype == torch.float8_e4m3fn:
            q, s = per_block_cast_to_fp8(w, block_shape)
            return q, s, None
        if quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    if quant_dtype == torch.int8:
        # NOTE(review): this assumes ops.scaled_int8_quant accepts
        # `use_per_token_if_dynamic` and returns exactly two values; confirm
        # against vllm._custom_ops.
        q, s = ops.scaled_int8_quant(
            w, w_s, use_per_token_if_dynamic=per_token_quant
        )
        return q, s, None
    if quant_dtype == torch.float8_e4m3fn:
        q, s = ops.scaled_fp8_quant(
            w, w_s, use_per_token_if_dynamic=per_token_quant
        )
        return q, s, None
    if quant_dtype == "nvfp4":
        assert not per_token_quant
        abs_max = torch.abs(w).max().to(torch.float32)
        global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / abs_max
        q, s = ops.scaled_fp4_quant(w, global_scale)
        return q, s, global_scale
    raise RuntimeError(f"Unsupported quant type {quant_dtype}")
bnellnm's avatar
bnellnm committed
295
296


297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Quantize a stack of expert weights ``(e, rows, cols)`` expert by expert.

    Each ``w[i]`` goes through ``moe_quantize_weights_2d``. The incoming
    ``w_s`` is not forwarded; ``None`` is passed instead. Results are stacked
    along a new leading expert dimension.

    Returns:
        ``(w_q, w_s, w_gs)``. A 2-D scale stack (one scale per expert, last
        dim 1) is reshaped to ``(e, 1, 1)``. ``w_gs`` is ``None`` unless the
        per-expert call produced global scales (nvfp4). With ``block_shape``,
        the scales are asserted to have shape ``(e, n_tiles, k_tiles)``.
    """
    assert w.dim() == 3
    num_experts, rows, cols = w.shape

    per_expert = [
        moe_quantize_weights_2d(w[i], None, quant_dtype, per_token_quant, block_shape)
        for i in range(num_experts)
    ]

    w_q = torch.stack([res[0] for res in per_expert])
    scales = torch.stack([res[1] for res in per_expert])
    has_global = num_experts > 0 and per_expert[0][2] is not None
    global_scales = torch.stack([res[2] for res in per_expert]) if has_global else None

    if scales.ndim == 2:
        assert scales.shape[-1] == 1
        scales = scales.view(-1, 1, 1)

    if block_shape is not None:
        block_n, block_k = block_shape
        n_tiles = (rows + block_n - 1) // block_n
        k_tiles = (cols + block_k - 1) // block_k
        assert scales.shape == (num_experts, n_tiles, k_tiles)

    return w_q, scales, global_scales


bnellnm's avatar
bnellnm committed
331
332
333
334
335
def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Create a random ``(e, rows, cols)`` CUDA weight stack, optionally quantized.

    Returns:
        ``(w_16, w, w_s, w_gs)``. ``w_16`` is the unquantized ``randn / 15``
        tensor in ``in_dtype``. When ``quant_dtype`` is ``None``, ``w`` is
        ``w_16`` and both scales are ``None``. Otherwise ``w``, ``w_s`` and
        ``w_gs`` come from ``moe_quantize_weights``, with ``per_out_ch_quant``
        passed as its ``per_token_quant``.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None, None

    w_q, w_s, w_gs = moe_quantize_weights(
        w_16, None, quant_dtype, per_out_ch_quant, block_shape
    )
    return w_16, w_q, w_s, w_gs
bnellnm's avatar
bnellnm committed
352
353
354
355
356
357
358


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """Create the two MoE weight stacks via ``make_test_weight``.

    The first stack has shape ``(e, 2 * n, k)`` when ``make_gate`` is set and
    ``(e, n, k)`` otherwise. The second has shape ``(e, k, n)``. Each entry is
    the ``(w_16, w, w_s, w_gs)`` tuple returned by ``make_test_weight``. The
    first stack is generated before the second.
    """
    w1_rows = 2 * n if make_gate else n
    first = make_test_weight(
        e, w1_rows, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    second = make_test_weight(
        e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    return first, second
379
380
381


def per_token_cast_to_fp8(
382
383
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
384
385
386
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
387
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
388
389
390
391
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
392
393


394
395
396
397
398
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """Create MoE test weights together with a matching quant config.

    ``per_act_token_quant`` is also used as the weights' ``per_out_ch_quant``.

    Returns:
        ``(w1, w2, quant_config)``, where ``w1`` and ``w2`` are the (possibly
        quantized) weight stacks from ``make_test_weights``.

    For nvfp4:
        - the activation global scales are all-ones placeholders on CUDA, and
          the same tensors are reused as ``a1_scale`` / ``a2_scale``;
        - ``g1_alphas`` / ``g2_alphas`` are the reciprocals of the weights'
          global scales.
    """
    w1_parts, w2_parts = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )
    _, w1, w1_s, w1_gs = w1_parts
    _, w2, w2_s, w2_gs = w2_parts

    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        # Hacky/trivial scales for nvfp4.
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
    # Activation scales alias the global scales (both None unless nvfp4).
    a1_scale, a2_scale = a1_gscale, a2_gscale

    # TODO: make sure this is handled properly
    g1_alphas = None if w1_gs is None else 1 / w1_gs
    g2_alphas = None if w2_gs is None else 1 / w2_gs

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
    )
    return w1, w2, quant_config


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """Route tokens with ``fused_topk`` and run them through ``fused_experts``.

    ``score`` is cast to float32 before routing. The routing weights and ids
    are then handed to ``fused_experts`` together with ``global_num_experts``,
    ``expert_map`` and ``quant_config``.
    """
    router_logits = score.float()
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, router_logits, topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
471
472


473
474
475
476
477
478
479
class BaselineMM(torch.nn.Module):
    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
480
        self.b = torch.nn.Parameter(b.to(dtype=torch.float32))
481
482
        self.out_dtype = out_dtype

483
    def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
484
        return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
485
486


487
488
489
490
491
492
493
494
495
class BaselineSiluAndMul(torch.nn.Module):
    """Reference SiLU-and-multiply over the last dimension.

    With ``d = x.shape[-1] // 2``, returns ``silu(x[..., :d]) * x[..., d:]``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return torch.nn.functional.silu(gate) * up


496
497
498
499
500
501
502
503
504
505
class TestMLP(torch.nn.Module):
    """Reference MLP computing ``down_proj(silu_and_mul(gate_up_proj(x)))``.

    Both projections are ``BaselineMM`` modules, i.e. float32 matmuls cast to
    ``out_dtype``. ``w1`` is the gate/up weight, whose output
    ``BaselineSiluAndMul`` splits in half. ``w2`` is the down-projection
    weight.
    """

    # The ``Test`` prefix is only a name. Without this flag, pytest tries to
    # collect the class (and warns about its ``__init__``) in any test module
    # that imports it.
    __test__ = False

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = BaselineSiluAndMul()

    def forward(self, x):
        # Each BaselineMM returns (output, None); the second element is dropped.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """Build a ``TestMLP`` reference shared-expert MLP with random CUDA weights.

    The gate/up weight is ``(K, 2 * N)`` and the down weight is ``(N, K)``.
    Both are ``randn / 15`` in ``in_dtype``, which is also the output dtype.
    """
    gate_up = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
    down = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
    return TestMLP(gate_up, down, out_dtype=in_dtype)


class RealMLP(torch.nn.Module):
    """Shared-expert MLP built from vLLM's parallel linear layers.

    ``gate_up_proj`` (a ``MergedColumnParallelLinear``) and ``down_proj`` (a
    ``RowParallelLinear``) are built with ``quant_config``. Afterwards:
    - their ``weight`` / ``weight_scale`` parameters are replaced by the given
      tensors, wrapped as non-trainable Parameters;
    - ``input_scale`` is registered as ``None``.

    Only ``hidden_act="silu"`` is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        def install_weights(
            layer: torch.nn.Module,
            weight: torch.Tensor,
            scale: torch.Tensor | None,
        ) -> None:
            # Overwrite weight and weight_scale, then clear input_scale.
            # torch.nn.Parameter(None) yields an empty parameter, so a missing
            # scale still registers a (0-element) weight_scale.
            for param_name, tensor in (("weight", weight), ("weight_scale", scale)):
                layer.register_parameter(
                    param_name, torch.nn.Parameter(tensor, requires_grad=False)
                )
            layer.register_parameter("input_scale", None)

        super().__init__()

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        install_weights(self.gate_up_proj, w1, w1_s)

        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        install_weights(self.down_proj, w2, w2_s)

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # The linear layers return 2-tuples; only the first element is used.
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out


591
def make_shared_experts_with_weights(
    N: int,
    K: int,
    in_dtype: torch.dtype,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_s: torch.Tensor | None = None,
    w2_s: torch.Tensor | None = None,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a ``RealMLP`` (hidden size ``K``, intermediate ``N``) from given weights.

    The torch default dtype is set to ``in_dtype`` while the module is built
    and restored afterwards, even if construction raises. An ``Fp8Config`` is
    used only when ``quant_dtype`` is ``torch.float8_e4m3fn``; otherwise the
    layers are unquantized.
    """
    saved_default = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        quant_config = None
        if quant_dtype == torch.float8_e4m3fn:
            from vllm.model_executor.layers.quantization.fp8 import Fp8Config

            # NOTE(review): the positional True is presumably
            # is_checkpoint_fp8_serialized; confirm against Fp8Config.
            quant_config = Fp8Config(True)
        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(saved_default)
614
615
616
617
618


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEKernel:
    """Assemble an out-of-place ``FusedMoEKernel`` around ``TritonExperts``.

    The prepare/finalize stage comes from ``maybe_make_prepare_finalize`` with
    ``allow_new_interface=True`` and ``use_monolithic=False``.
    """
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=quant_config,
        allow_new_interface=True,
        use_monolithic=False,
    )
    experts = TritonExperts(moe_config, quant_config)
    return FusedMoEKernel(prepare_finalize, experts, inplace=False)
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a ``RealMLP`` shared expert from freshly generated test weights.

    The weights come from ``make_test_weights`` with a single expert (``e=1``).
    Only the possibly-quantized tensors and their scales are used; the
    unquantized copies and global scales are discarded.

    NOTE(review): ``make_test_weight`` returns ``(e, rows, cols)`` tensors, so
    ``w1``/``w2`` keep a leading expert dimension of size 1 here. Confirm that
    ``RealMLP``'s linear layers accept 3-D weights.
    """
    w1_bundle, w2_bundle = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    _, w1, w1_s, _ = w1_bundle
    _, w2, w2_s, _ = w2_bundle

    return make_shared_experts_with_weights(
        N, K, in_dtype, w1, w2, w1_s=w1_s, w2_s=w2_s, quant_dtype=quant_dtype
    )