common.py 23.8 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
4
from typing import Any
5
6
7
8
9

import torch

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
10
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
11
12
13
14
15
from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
)
16
17
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
18
19
20
21
22
from vllm.distributed import (
    get_dp_group,
    get_pcp_group,
    get_tensor_model_parallel_world_size,
)
23
from vllm.forward_context import set_forward_context
24
from vllm.model_executor.layers.fused_moe import fused_topk
25
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
26
27
28
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
31
32
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
)
35
36
37
38
39
40
41
42
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8Static128BlockSym,
    kFp8StaticChannelSym,
    kFp8StaticTensorSym,
)
43
44
45
46
47
48
from vllm.utils.import_utils import (
    has_aiter,
    has_deep_ep,
    has_deep_gemm,
    has_mori,
)
49
from vllm.utils.math_utils import next_power_of_2
50

51
52
53
54
55
56
from .mk_objects import (
    TestMoEQuantConfig,
    expert_info,
    make_fused_experts,
    prepare_finalize_info,
)
57
58
59
from .parallel_utils import ProcessGroupInfo


60
def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
61
62
63
64
65
66
67
68
    if t is None:
        return f"{name} : None"
    else:
        return f"{name} : {t.shape} {t.dtype} {t.device}"


@dataclass
class Config:
    """Description of a single modular-kernel MoE test case.

    Bundles the problem sizes, activation dtype, quantization scheme and the
    prepare/finalize + fused-experts implementation classes under test, plus
    helpers that query the ``mk_objects`` registries about those classes and
    decide whether the combination is runnable (``is_valid``).

    ``Ms`` and ``topks`` may hold a list (a sweep) or a single int; the scalar
    accessors ``M`` and ``topk`` assert the int form.
    """

    Ms: list[int] | int
    K: int
    N: int
    E: int
    topks: list[int] | int
    dtype: torch.dtype
    quant_config: TestMoEQuantConfig | None

    # These hold implementation *classes* (``describe`` reads ``__name__`` and
    # they are passed to the ``mk_objects`` registry lookups), not instances.
    prepare_finalize_type: type[mk.FusedMoEPrepareAndFinalize]
    fused_experts_type: type[mk.FusedMoEExperts]

    world_size: int

    torch_trace_dir_path: str | None = None

    def __post_init__(self):
        # Normalize "no quantization" to an explicit empty TestMoEQuantConfig
        # so the quant accessors below never see quant_config=None.
        if self.quant_config is None:
            self.quant_config = TestMoEQuantConfig(None, False, False, None)

    def describe(self) -> str:
        """Return a multi-line, human-readable summary of this config."""
        lines = [
            "== Config:",
            f" world_size={self.world_size}",
            f" PF={self.prepare_finalize_type.__name__}",
            f" FE={self.fused_experts_type.__name__}",
            f" E={self.E}",
            f" Ms={self.Ms}",
            f" N={self.N}",
            f" K={self.K}",
            f" topk={self.topks}",
            f" dtype={self.dtype}",
            " Quant:",
        ]
        if self.quant_config is not None:
            lines += [
                f"     q_dtype={self.quant_dtype}",
                f"     q_block_shape={self.quant_block_shape}",
                f"     q_per_out_ch_quant={self.is_per_out_ch_quant}",
                f"     q_per_act_token={self.is_per_act_token_quant}",
            ]
        else:
            lines.append("     quant=None")
        return "".join(f"{line}\n" for line in lines)

    @property
    def M(self) -> int:
        """The single batch size; only valid when ``Ms`` is an int."""
        assert isinstance(self.Ms, int)
        return self.Ms

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        """Quantized dtype (e.g. ``torch.float8_e4m3fn`` or ``"nvfp4"``), or None."""
        assert self.quant_config is not None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        assert self.quant_config is not None
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        # Per-tensor is the fallback when neither per-token nor block
        # quantization is selected; note it is True even if quant_dtype is None.
        return not self.is_per_act_token_quant and self.quant_block_shape is None

    @property
    def is_per_out_ch_quant(self) -> bool:
        assert self.quant_config is not None
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> list[int] | None:
        assert self.quant_config is not None
        return self.quant_config.block_shape

    @property
    def topk(self) -> int:
        """The single top-k value; only valid when ``topks`` is an int."""
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def num_local_experts(self) -> int:
        # Floor division: any remainder of E / world_size is not assigned.
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """Build the launch configuration for this test case.

        Returns:
            ``(vllm_config, env_dict)``: a VllmConfig with data parallelism of
            ``world_size``, expert parallelism enabled and the all2all backend
            of the prepare/finalize class; and the environment variables to
            set (``VLLM_USE_DEEP_GEMM`` as ``"0"``/``"1"``).
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True

        env_dict = {
            "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
        }

        vllm_config.parallel_config.all2all_backend = self.all2all_backend()

        return vllm_config, env_dict

    def fe_supports_quant_scheme(self) -> bool:
        """Check if the fused experts class supports this quant config.
        See https://github.com/ROCm/aiter/issues/2419 for AITER gaps."""
        # Only fp8 schemes are checked; everything else is assumed supported.
        if self.quant_config is None or self.quant_dtype is None:
            return True
        if self.quant_dtype != torch.float8_e4m3fn:
            return True
        # Derive QuantKeys from test config
        if self.quant_block_shape is not None:
            w_key = kFp8Static128BlockSym
            a_key = kFp8Dynamic128Sym
        elif self.is_per_out_ch_quant:
            w_key = kFp8StaticChannelSym
            a_key = (
                kFp8DynamicTokenSym
                if self.is_per_act_token_quant
                else kFp8StaticTensorSym
            )
        else:
            w_key = kFp8StaticTensorSym
            # NOTE(review): for per-token activations this picks a dynamic
            # *per-tensor* activation key, while the per-channel branch above
            # picks kFp8DynamicTokenSym for the same activation setting.
            # Confirm which key matches what the kernel actually receives.
            a_key = (
                kFp8DynamicTensorSym
                if self.is_per_act_token_quant
                else kFp8StaticTensorSym
            )
        fe_cls = self.fused_experts_type
        if hasattr(fe_cls, "_supports_quant_scheme"):
            try:
                return fe_cls._supports_quant_scheme(w_key, a_key)
            except NotImplementedError:
                # Classes that do not implement the check are treated as
                # supporting every scheme (fall through to True).
                pass
        return True

    def is_fp8_block_quantized(self):
        return (
            self.quant_dtype == torch.float8_e4m3fn
            and self.quant_block_shape is not None
        )

    def is_batched_prepare_finalize(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format

    def is_batched_fused_experts(self):
        info = expert_info(self.fused_experts_type)
        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format

    def is_standard_fused_experts(self):
        info = expert_info(self.fused_experts_type)
        return mk.FusedMoEActivationFormat.Standard == info.activation_format

    def fe_supported_types(self):
        info = expert_info(self.fused_experts_type)
        return info.supported_dtypes

    def pf_supported_types(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.supported_dtypes

    def is_block_quant_supported(self):
        info = expert_info(self.fused_experts_type)
        return info.blocked_quantization_support

    def supports_expert_map(self):
        info = expert_info(self.fused_experts_type)
        return info.supports_expert_map

    def supports_apply_weight_on_input(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.supports_apply_weight_on_input

    def needs_deep_gemm(self):
        info = expert_info(self.fused_experts_type)
        return info.needs_deep_gemm

    def needs_deep_ep(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend in ("deepep_high_throughput", "deepep_low_latency")

    def needs_aiter(self):
        info = expert_info(self.fused_experts_type)
        return info.needs_aiter

    def needs_mori(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend == "mori"

    def all2all_backend(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend

    def is_valid(self) -> tuple[bool, str | None]:
        """Decide whether this config can be run in the current environment.

        Returns:
            ``(True, None)`` if runnable, otherwise ``(False, reason)``.
        """
        # Check prepare-finalize and fused-experts compatibility: batched
        # prepare/finalize needs batched experts, otherwise standard experts.
        formats_match = (
            self.is_batched_fused_experts()
            if self.is_batched_prepare_finalize()
            else self.is_standard_fused_experts()
        )
        if not formats_match:
            return False, "Mismatched format."

        # Check quantization sanity: at most one activation scheme may be set.
        if (
            int(self.is_per_act_token_quant)
            + int(self.is_per_tensor_act_quant)
            + int(self.quant_block_shape is not None)
        ) > 1:
            # invalid quant config
            return False, f"Bad quant_config {self.quant_config}."

        # check type support
        if self.quant_dtype is None:
            if (
                self.dtype not in self.pf_supported_types()
                or self.dtype not in self.fe_supported_types()
            ):
                return False, (
                    f"Unsupported type {self.dtype} not in "
                    f"{self.pf_supported_types()} and "
                    f"{self.fe_supported_types()}."
                )
        else:
            if (
                self.quant_dtype not in self.pf_supported_types()
                or self.quant_dtype not in self.fe_supported_types()
            ):
                return False, (
                    f"Unsupported quant type {self.quant_dtype} "
                    f"not in {self.pf_supported_types()} and "
                    f"{self.fe_supported_types()}."
                )

        # Check quant scheme compatibility with fused experts class
        if not self.fe_supports_quant_scheme():
            return False, (
                f"FE {self.fused_experts_type.__name__} does not support "
                f"quant scheme (per_out_ch={self.is_per_out_ch_quant}, "
                f"per_act_token={self.is_per_act_token_quant}, "
                f"block={self.quant_block_shape})"
            )

        # Check block quantization support
        is_block_quantized = self.quant_block_shape is not None
        if is_block_quantized and self.quant_dtype is None:
            return False, "No block quantization support."

        if is_block_quantized and not self.is_block_quant_supported():
            return False, "Mismatched block quantization support."

        # deep_gemm only works with block-quantized
        if self.needs_deep_gemm() and not is_block_quantized:
            return False, "Needs DeepGEMM but not block quantized."

        # Check dependencies (turn into asserts?)
        if self.needs_deep_ep() and not has_deep_ep():
            return False, "Needs DeepEP, but DeepEP not available."
        if self.needs_deep_gemm() and not has_deep_gemm():
            return False, "Needs DeepGEMM, but DeepGEMM not available."
        if self.needs_aiter() and not has_aiter():  # noqa: SIM103
            return False, "Needs Aiter, but Aiter not available."
        if self.needs_mori() and not has_mori():  # noqa: SIM103
            return False, "Needs MoRI, but MoRI not available."

        return True, None
330
331
332
333
334
335


@dataclass
class WeightTensors:
    """Expert weights and their quantization scales for one MoE layer.

    ``w1``/``w2`` (and the ``*_scale``/``*_gs`` tensors when present) are
    indexed by expert along dim 0, which is what ``slice_weights`` relies on.
    ``w1_gs``/``w2_gs`` are optional extra scales; the nvfp4 reference path
    requires them, other paths may leave them as None.
    """

    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: torch.Tensor | None
    w2_scale: torch.Tensor | None
    w1_gs: torch.Tensor | None = None
    w2_gs: torch.Tensor | None = None

    def describe(self):
        """Return a multi-line summary of every weight and scale tensor."""
        labelled = (
            (self.w1, "w1"),
            (self.w2, "w2"),
            (self.w1_scale, "w1_scale"),
            (self.w2_scale, "w2_scale"),
            (self.w1_gs, "w1_gs"),
            (self.w2_gs, "w2_gs"),
        )
        body = "".join(
            f" - {_describe_tensor(tensor, label)} \n" for tensor, label in labelled
        )
        return "== Weight Tensors: \n" + body

    def is_quantized(self) -> bool:
        """True when ``w1`` is stored as fp8 (e4m3fn), uint8 or int8."""
        # Only w1's dtype is inspected; the scale tensors are not consulted.
        return self.w1.dtype in (torch.float8_e4m3fn, torch.uint8, torch.int8)

    def to_current_device(self):
        """Move all tensors onto the current accelerator device, in place.

        ``w1``/``w2`` are always moved; optional tensors that are None are
        left as None.
        """
        target = torch.accelerator.current_device_index()
        self.w1 = self.w1.to(device=target)
        self.w2 = self.w2.to(device=target)
        for attr in ("w1_scale", "w2_scale", "w1_gs", "w2_gs"):
            tensor = getattr(self, attr)
            if tensor is not None:
                setattr(self, attr, tensor.to(device=target))

    def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
        """Return the experts owned by ``rank`` as a new WeightTensors.

        Rank ``r`` owns experts ``[r * num_local_experts, (r + 1) *
        num_local_experts)``. Weights and block scales are indexed as 3-D
        tensors; the ``*_gs`` tensors are sliced along dim 0 only.
        """
        lo = rank * num_local_experts
        hi = lo + num_local_experts

        def take_3d(t):
            return None if t is None else t[lo:hi, :, :]

        def take_1d(t):
            return None if t is None else t[lo:hi]

        return WeightTensors(
            self.w1[lo:hi, :, :],
            self.w2[lo:hi, :, :],
            take_3d(self.w1_scale),
            take_3d(self.w2_scale),
            take_1d(self.w1_gs),
            take_1d(self.w2_gs),
        )

    @staticmethod
    def make(config: Config) -> "WeightTensors":
        """Create random test weights (and scales) matching ``config``."""
        w1_parts, w2_parts = make_test_weights(
            e=config.E,
            n=config.N,
            k=config.K,
            in_dtype=config.dtype,
            quant_dtype=config.quant_dtype,
            block_shape=config.quant_block_shape,
            # NOTE(review): the weights' per-output-channel flag follows the
            # per-act-token setting rather than config.is_per_out_ch_quant
            # (the original left "or config.is_per_out_ch_quant" as an
            # alternative). Confirm this pairing is intended.
            per_out_ch_quant=config.is_per_act_token_quant,
        )
        _, w1, w1_scale, w1_gs = w1_parts
        _, w2, w2_scale, w2_gs = w2_parts
        return WeightTensors(
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_gs=w1_gs,
            w2_gs=w2_gs,
        )


@dataclass
class RankTensors:
    """Per-rank activations and routing inputs for one modular-kernel run."""

    hidden_states: torch.Tensor
    # As produced by ``make``: the activation scale forwarded to the kernel as
    # a1_scale; only the per-tensor activation path sets it, else None.
    hidden_states_scale: torch.Tensor | None

    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    # As produced by ``make``: global expert id -> local expert id on this
    # rank (-1 for experts owned by other ranks); None when world_size == 1
    # or the fused experts class does not support an expert map.
    expert_map: torch.Tensor | None

    def describe(self):
        """Return a multi-line summary of every rank tensor."""
        s = ""
        s += "== Rank Tensors: \n"
        s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n"
        s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n"
        s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n"
        s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n"
        s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n"
        return s

    @staticmethod
    def make_hidden_states(
        config: Config,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Return ``(hidden_states, hidden_states_scale)`` of shape ``(M, K)``.

        Without quantization the random activations are returned as-is with a
        None scale. Otherwise they are round-tripped through the configured
        fp8 scheme; only the per-tensor scheme also returns its scale.
        """
        m, k, dtype = (config.M, config.K, config.dtype)
        device = torch.accelerator.current_device_index()
        a = torch.randn((m, k), device=device, dtype=dtype) / 15.0

        if config.quant_dtype is None:
            return a, None

        # We dequantize and use the result as hidden_states so the tests are
        # stable: quantizing and dequantizing yield slightly different results
        # depending on the hardware. Quantizing and dequantizing once up front
        # means any further quantize/dequantize yields the same values.
        if config.is_per_tensor_act_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
            return a_q.float().mul(a_scales).to(dtype), a_scales

        if config.is_per_act_token_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return a_q.float().mul(a_scales).to(dtype), None

        assert config.quant_block_shape is not None
        block_k = config.quant_block_shape[1]
        a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
        return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(
            dtype
        ), None

    @staticmethod
    def make(config: Config, pgi: ProcessGroupInfo):
        """Build random activations, routing and (optionally) an expert map
        for rank ``pgi.rank``, all on the current accelerator device."""
        dtype = config.dtype
        topk, m = config.topk, config.M
        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)

        num_local_experts, global_num_experts = (config.num_local_experts, config.E)
        # Use the same device as every other tensor in this helper instead of
        # a hard-coded "cuda" string.
        device = torch.accelerator.current_device_index()
        score = torch.randn((m, global_num_experts), device=device, dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)

        # Overwrite the routed ids: each token gets `topk` distinct experts
        # drawn uniformly at random, spreading load across experts.
        for mi in range(m):
            topk_ids[mi] = torch.randperm(config.E)[:topk]
        topk_ids = topk_ids.to(device=device)

        expert_map = None
        if config.world_size > 1 and config.supports_expert_map():
            expert_map = torch.full(
                (global_num_experts,), fill_value=-1, dtype=torch.int32, device=device
            )
            start = pgi.rank * num_local_experts
            expert_map[start : start + num_local_experts] = torch.arange(
                num_local_experts, dtype=torch.int32, device=device
            )

        return RankTensors(
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
        )


492
493
494
def reference_moe_impl(
    config: Config, weights: WeightTensors, rank_tensors: RankTensors
) -> torch.Tensor:
    """Compute the reference MoE output for ``rank_tensors`` via torch_experts.

    For ``quant_dtype == "nvfp4"`` the activations and all expert weights are
    quantized and dequantized up front, and the reference then runs
    unquantized in ``config.dtype``. For every other quant type, the tensors
    and quant settings are forwarded to ``torch_experts`` unchanged. The full
    expert set is used (``global_num_experts=config.E``, no expert map).
    """
    if config.quant_dtype == "nvfp4":
        quant_blocksize = 16
        dtype = config.dtype
        hidden_states = rank_tensors.hidden_states

        w1_q = weights.w1
        w1_blockscale = weights.w1_scale
        w1_gs = weights.w1_gs

        w2_q = weights.w2
        w2_blockscale = weights.w2_scale
        w2_gs = weights.w2_gs

        # The global scale maps the tensor's absolute maximum onto the largest
        # fp8 (block-scale) * fp4 (element) product. Use |x|: the signed max
        # would clip a dominant negative value and would give a negative
        # scale for an all-negative input.
        a_global_scale = (
            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
            / torch.amax(hidden_states.abs().flatten(), dim=-1)
        ).to(torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        assert w1_blockscale.shape[1] % 128 == 0
        assert w1_blockscale.shape[2] % 4 == 0
        assert w2_blockscale.shape[1] % 128 == 0
        assert w2_blockscale.shape[2] % 4 == 0

        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
            hidden_states, a_global_scale
        )

        a = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=dtype,
            device=a_fp4.device,
            block_size=quant_blocksize,
        )

        # w1 stacks two N-sized projections along dim 1, hence the // 2.
        e = w1_q.shape[0]
        n = w1_q.shape[1] // 2
        k = w2_q.shape[1]

        # Allocate the dequantized weights next to the activations (rather
        # than on a hard-coded "cuda" device) so torch_experts sees a single
        # device.
        w1 = torch.zeros((e, 2 * n, k), device=a.device, dtype=dtype)
        w2 = torch.zeros((e, k, n), device=a.device, dtype=dtype)

        for idx in range(0, e):
            w1[idx] = dequantize_nvfp4_to_dtype(
                w1_q[idx],
                w1_blockscale[idx],
                w1_gs[idx],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2[idx] = dequantize_nvfp4_to_dtype(
                w2_q[idx],
                w2_blockscale[idx],
                w2_gs[idx],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        # Everything is already dequantized: run the reference unquantized.
        a_scale = None
        w1_scale = None
        w2_scale = None
        quant_dtype = None
        per_act_token_quant = False
        block_shape = None
    else:
        a = rank_tensors.hidden_states
        a_scale = rank_tensors.hidden_states_scale
        w1 = weights.w1
        w1_scale = weights.w1_scale
        w2 = weights.w2
        w2_scale = weights.w2_scale
        quant_dtype = config.quant_dtype
        per_act_token_quant = config.is_per_act_token_quant
        block_shape = config.quant_block_shape

    return torch_experts(
        a=a,
        w1=w1,
        w2=w2,
        topk_weight=rank_tensors.topk_weights,
        topk_ids=rank_tensors.topk_ids,
        global_num_experts=config.E,
        expert_map=None,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a_scale,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        apply_router_weights_on_input=config.topk == 1
        and config.supports_apply_weight_on_input(),
    )
593
594


595
def _make_gscale(num_experts: int) -> torch.Tensor:
596
    return torch.ones(
597
598
599
        (num_experts,),
        device=torch.accelerator.current_device_index(),
        dtype=torch.float32,
600
    )
601
602


603
604
605
def make_modular_kernel(
    config: Config,
    vllm_config: VllmConfig,
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEKernel:
    """Assemble the FusedMoEKernel under test for ``config``.

    Builds a FusedMoEConfig from the current TP/PCP/DP group sizes and
    ``vllm_config``. It then pairs the prepare/finalize object returned by
    ``maybe_make_prepare_finalize`` with an instance of
    ``config.fused_experts_type``. The kernel is built non-inplace.
    """
    # Parallel layout taken from the live distributed groups.
    parallel_cfg: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        pcp_size_=get_pcp_group().world_size,
        dp_size_=get_dp_group().world_size,
        sp_size_=1,
        vllm_parallel_config=vllm_config.parallel_config,
    )

    moe_cfg = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
        hidden_dim=config.K,
        intermediate_size_per_partition=config.N,
        num_local_experts=config.num_local_experts,
        num_logical_experts=config.E,
        moe_parallel_config=parallel_cfg,
        in_dtype=config.dtype,
        max_num_tokens=next_power_of_2(config.M),
        activation=MoEActivation.SILU,
        device=vllm_config.device_config.device,
        routing_method=RoutingMethodType.DeepSeekV3,
    )

    pf = maybe_make_prepare_finalize(
        moe=moe_cfg,
        quant_config=quant_config,
        allow_new_interface=True,
    )
    assert pf is not None

    experts = make_fused_experts(
        config.fused_experts_type,
        moe_cfg,
        quant_config,
        pf.num_dispatchers(),
        config.N,
    )

    return mk.FusedMoEKernel(
        prepare_finalize=pf,
        fused_experts=experts,
        inplace=False,
    )


def run_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> torch.Tensor:
    """Run the modular kernel under test on this rank and return its output.

    Slices out the experts owned by ``pgi.rank``, builds the matching
    FusedMoEQuantConfig and kernel, and calls ``apply`` inside a forward
    context in which every DP rank reports the same token count.
    """
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # nvfp4 takes activation global scales; the test uses unit scales.
    gscale = (
        _make_gscale(config.num_local_experts)
        if config.quant_dtype == "nvfp4"
        else None
    )

    quant_config = FusedMoEQuantConfig.make(
        config.quant_dtype,
        w1_scale=rank_weights.w1_scale,
        w2_scale=rank_weights.w2_scale,
        a1_scale=rank_tensors.hidden_states_scale,
        g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
        g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
        a1_gscale=gscale,
        a2_gscale=gscale,
        block_shape=config.quant_block_shape,
        per_act_token_quant=config.is_per_act_token_quant,
        per_out_ch_quant=config.is_per_out_ch_quant,
    )

    # Named `kernel` rather than `mk` so the module-level
    # `modular_kernel as mk` alias is not shadowed inside this function.
    kernel = make_modular_kernel(config, vllm_config, quant_config)

    # impls might update the tensor in place
    hidden_states = rank_tensors.hidden_states.clone()

    topk_ids = rank_tensors.topk_ids.to(kernel.prepare_finalize.topk_indices_dtype())

    mk_kwargs = {
        "hidden_states": hidden_states,
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": topk_ids,
        "activation": MoEActivation.SILU,
        "expert_map": rank_tensors.expert_map,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1
        and config.supports_apply_weight_on_input(),
    }

    num_tokens = rank_tensors.hidden_states.shape[0]
    # Every DP rank runs the same number of tokens. Keep the tensor on the
    # activations' device instead of a hard-coded "cuda" string.
    num_tokens_across_dp = torch.tensor(
        [num_tokens] * config.world_size,
        device=hidden_states.device,
        dtype=torch.int,
    )

    with set_forward_context(
        None,
        vllm_config,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
    ):
        out = kernel.apply(**mk_kwargs)

    return out