"vllm/v1/executor/uniproc_executor.py" did not exist on "4fdd6f5cbf877de7c4de33086fe41bb0ac1d3cf3"
common.py 23.9 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
4
from typing import Any
5
6
7
8
9

import torch

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
10
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
11
12
13
14
15
from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
)
16
17
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
18
19
20
21
22
from vllm.distributed import (
    get_dp_group,
    get_pcp_group,
    get_tensor_model_parallel_world_size,
)
23
from vllm.forward_context import set_forward_context
24
from vllm.model_executor.layers.fused_moe import fused_topk
25
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
26
27
28
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
31
32
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
)
35
36
37
38
39
40
41
42
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8Static128BlockSym,
    kFp8StaticChannelSym,
    kFp8StaticTensorSym,
)
43
44
45
46
47
48
from vllm.utils.import_utils import (
    has_aiter,
    has_deep_ep,
    has_deep_gemm,
    has_mori,
)
49

50
51
52
53
54
55
from .mk_objects import (
    TestMoEQuantConfig,
    expert_info,
    make_fused_experts,
    prepare_finalize_info,
)
56
57
58
from .parallel_utils import ProcessGroupInfo


59
def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
60
61
62
63
64
65
66
67
    if t is None:
        return f"{name} : None"
    else:
        return f"{name} : {t.shape} {t.dtype} {t.device}"


@dataclass
class Config:
    """One modular-kernel MoE test configuration.

    ``Ms`` and ``topks`` are typed as either a list or a single int; the
    ``M`` and ``topk`` properties assert the int form before returning it.
    """

    Ms: list[int] | int
    K: int
    N: int
    E: int
    topks: list[int] | int
    dtype: torch.dtype
    quant_config: TestMoEQuantConfig | None

    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
    fused_experts_type: mk.FusedMoEExperts

    world_size: int

    torch_trace_dir_path: str | None = None

    def __post_init__(self):
        # Replace a missing quant config with a placeholder instance so the
        # quant_* properties below always have an object to read from.
        if self.quant_config is None:
            self.quant_config = TestMoEQuantConfig(None, False, False, None)

    def describe(self) -> str:
        """Return a multi-line, human-readable dump of this config."""
        lines = [
            "== Config:",
            f" world_size={self.world_size}",
            f" PF={self.prepare_finalize_type.__name__}",
            f" FE={self.fused_experts_type.__name__}",
            f" E={self.E}",
            f" Ms={self.Ms}",
            f" N={self.N}",
            f" K={self.K}",
            f" topk={self.topks}",
            f" dtype={self.dtype}",
            " Quant:",
        ]
        if self.quant_config is not None:
            lines += [
                f"     q_dtype={self.quant_dtype}",
                f"     q_block_shape={self.quant_block_shape}",
                f"     q_per_out_ch_quant={self.is_per_out_ch_quant}",
                f"     q_per_act_token={self.is_per_act_token_quant}",
            ]
        else:
            lines.append("     quant=None")
        # Every line, including the last one, is newline-terminated.
        return "\n".join(lines) + "\n"

    @property
    def M(self) -> int:
        """``Ms`` as a single int (asserts that it is not a list)."""
        assert isinstance(self.Ms, int)
        return self.Ms

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        """``quant_dtype`` of the quant config."""
        assert self.quant_config is not None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        """``per_act_token_quant`` of the quant config."""
        assert self.quant_config is not None
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        """True when neither per-token nor block quantization is selected."""
        return not self.is_per_act_token_quant and self.quant_block_shape is None

    @property
    def is_per_out_ch_quant(self) -> bool:
        """``per_out_ch_quant`` of the quant config."""
        assert self.quant_config is not None
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> list[int] | None:
        """``block_shape`` of the quant config."""
        assert self.quant_config is not None
        return self.quant_config.block_shape

    @property
    def topk(self) -> int:
        """``topks`` as a single int (asserts that it is not a list)."""
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def num_local_experts(self) -> int:
        """Experts per rank: ``E // world_size`` (floor division)."""
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """
        make env data for vllm launch.

        Returns a fresh ``VllmConfig`` with data-parallel size set to
        ``world_size``, expert parallelism enabled and the all2all backend
        taken from the prepare-finalize type, together with a dict holding
        the ``VLLM_USE_DEEP_GEMM`` setting as ``"0"`` or ``"1"``.
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True

        env_dict = {
            "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
        }

        vllm_config.parallel_config.all2all_backend = self.all2all_backend()

        return vllm_config, env_dict

    def fe_supports_quant_scheme(self) -> bool:
        """Check if the fused experts class supports this quant config.
        See https://github.com/ROCm/aiter/issues/2419 for AITER gaps.

        Only fp8 (``torch.float8_e4m3fn``) configs are actually checked;
        every other case, including a fused-experts class without a
        ``_supports_quant_scheme`` hook or one raising
        ``NotImplementedError``, is reported as supported.
        """
        if self.quant_config is None or self.quant_dtype is None:
            return True
        if self.quant_dtype != torch.float8_e4m3fn:
            return True

        # Derive QuantKeys from test config
        if self.quant_block_shape is not None:
            w_key = kFp8Static128BlockSym
            a_key = kFp8Dynamic128Sym
        elif self.is_per_out_ch_quant:
            w_key = kFp8StaticChannelSym
            a_key = (
                kFp8DynamicTokenSym
                if self.is_per_act_token_quant
                else kFp8StaticTensorSym
            )
        else:
            w_key = kFp8StaticTensorSym
            a_key = (
                kFp8DynamicTensorSym
                if self.is_per_act_token_quant
                else kFp8StaticTensorSym
            )

        fe_cls = self.fused_experts_type
        if not hasattr(fe_cls, "_supports_quant_scheme"):
            return True
        try:
            return fe_cls._supports_quant_scheme(w_key, a_key)
        except NotImplementedError:
            # A hook that is declared but not implemented is deliberately
            # treated as "no restriction".
            return True

    def is_fp8_block_quantized(self):
        """True for fp8 quantization with a block shape set."""
        return (
            self.quant_dtype == torch.float8_e4m3fn
            and self.quant_block_shape is not None
        )

    def is_batched_prepare_finalize(self):
        """True if the prepare-finalize type uses the BatchedExperts format."""
        info = prepare_finalize_info(self.prepare_finalize_type)
        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format

    def is_batched_fused_experts(self):
        """True if the fused-experts type uses the BatchedExperts format."""
        info = expert_info(self.fused_experts_type)
        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format

    def is_standard_fused_experts(self):
        """True if the fused-experts type uses the Standard format."""
        info = expert_info(self.fused_experts_type)
        return mk.FusedMoEActivationFormat.Standard == info.activation_format

    def fe_supported_types(self):
        """``supported_dtypes`` registered for the fused-experts type."""
        return expert_info(self.fused_experts_type).supported_dtypes

    def pf_supported_types(self):
        """``supported_dtypes`` registered for the prepare-finalize type."""
        return prepare_finalize_info(self.prepare_finalize_type).supported_dtypes

    def is_block_quant_supported(self):
        """``blocked_quantization_support`` of the fused-experts type."""
        return expert_info(self.fused_experts_type).blocked_quantization_support

    def supports_expert_map(self):
        """``supports_expert_map`` of the fused-experts type."""
        return expert_info(self.fused_experts_type).supports_expert_map

    def supports_apply_weight_on_input(self):
        """``supports_apply_weight_on_input`` of the prepare-finalize type."""
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.supports_apply_weight_on_input

    def needs_deep_gemm(self):
        """``needs_deep_gemm`` of the fused-experts type."""
        return expert_info(self.fused_experts_type).needs_deep_gemm

    def needs_deep_ep(self):
        """True if the prepare-finalize backend is one of the DeepEP ones."""
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend in ("deepep_high_throughput", "deepep_low_latency")

    def needs_aiter(self):
        """``needs_aiter`` of the fused-experts type."""
        return expert_info(self.fused_experts_type).needs_aiter

    def needs_mori(self):
        """True if the prepare-finalize backend is ``"mori"``."""
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend == "mori"

    def all2all_backend(self):
        """Backend name registered for the prepare-finalize type."""
        return prepare_finalize_info(self.prepare_finalize_type).backend

    def is_valid(self) -> tuple[bool, str | None]:
        """Validate the combination of kernels, dtypes and quantization.

        Returns:
            ``(True, None)`` when every check passes, otherwise
            ``(False, reason)`` for the first failing check.
        """
        # Check prepare-finalize and fused-experts compatibility
        if self.is_batched_prepare_finalize():
            if not self.is_batched_fused_experts():
                return False, "Mismatched format."
        elif not self.is_standard_fused_experts():
            return False, "Mismatched format."

        # Check quantization sanity: at most one activation-quant mode.
        num_act_quant_modes = (
            int(self.is_per_act_token_quant)
            + int(self.is_per_tensor_act_quant)
            + int(self.quant_block_shape is not None)
        )
        if num_act_quant_modes > 1:
            # invalid quant config
            return False, f"Bad quant_config {self.quant_config}."

        # check type support
        pf_types = self.pf_supported_types()
        fe_types = self.fe_supported_types()
        if self.quant_dtype is None:
            if self.dtype not in pf_types or self.dtype not in fe_types:
                return False, (
                    f"Unsupported type {self.dtype} not in "
                    f"{pf_types} and "
                    f"{fe_types}."
                )
        elif self.quant_dtype not in pf_types or self.quant_dtype not in fe_types:
            return False, (
                f"Unsupported quant type {self.quant_dtype} "
                f"not in {pf_types} and "
                f"{fe_types}."
            )

        # Check quant scheme compatibility with fused experts class
        if not self.fe_supports_quant_scheme():
            return False, (
                f"FE {self.fused_experts_type.__name__} does not support "
                f"quant scheme (per_out_ch={self.is_per_out_ch_quant}, "
                f"per_act_token={self.is_per_act_token_quant}, "
                f"block={self.quant_block_shape})"
            )

        # Check block quantization support
        is_block_quantized = self.quant_block_shape is not None
        if is_block_quantized and self.quant_dtype is None:
            return False, "No block quantization support."

        if is_block_quantized and not self.is_block_quant_supported():
            return False, "Mismatched block quantization support."

        # deep_gemm only works with block-quantized
        if self.needs_deep_gemm() and not is_block_quantized:
            return False, "Needs DeepGEMM but not block quantized."

        # Check dependencies (turn into asserts?)
        if self.needs_deep_ep() and not has_deep_ep():
            return False, "Needs DeepEP, but DeepEP not available."
        if self.needs_deep_gemm() and not has_deep_gemm():
            return False, "Needs DeepGEMM, but DeepGEMM not available."
        if self.needs_aiter() and not has_aiter():  # noqa: SIM103
            return False, "Needs Aiter, but Aiter not available."
        if self.needs_mori() and not has_mori():  # noqa: SIM103
            return False, "Needs MoRI, but MoRI not available."

        return True, None
329
330
331
332
333
334


@dataclass
class WeightTensors:
    """Expert weights ``w1``/``w2`` with optional scale and ``gs`` tensors."""

    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: torch.Tensor | None
    w2_scale: torch.Tensor | None
    w1_gs: torch.Tensor | None = None
    w2_gs: torch.Tensor | None = None

    def describe(self):
        """Return a multi-line shape/dtype/device summary of every tensor."""
        named_tensors = (
            (self.w1, "w1"),
            (self.w2, "w2"),
            (self.w1_scale, "w1_scale"),
            (self.w2_scale, "w2_scale"),
            (self.w1_gs, "w1_gs"),
            (self.w2_gs, "w2_gs"),
        )
        header = "== Weight Tensors: \n"
        return header + "".join(
            f" - {_describe_tensor(tensor, label)} \n"
            for tensor, label in named_tensors
        )

    def is_quantized(self) -> bool:
        """True if ``w1`` is stored as fp8 (e4m3fn), uint8 or int8."""
        # or w1_scale is not None?
        return self.w1.dtype in (torch.float8_e4m3fn, torch.uint8, torch.int8)

    def to_current_device(self):
        """Move every tensor, in place, to the current accelerator device."""
        device = torch.accelerator.current_device_index()
        self.w1 = self.w1.to(device=device)
        self.w2 = self.w2.to(device=device)

        # The scale / gs tensors are optional; move only those that exist.
        for attr in ("w1_scale", "w2_scale", "w1_gs", "w2_gs"):
            tensor = getattr(self, attr)
            if tensor is not None:
                setattr(self, attr, tensor.to(device=device))

    def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
        """Return the slice of experts owned by ``rank``.

        Rank ``r`` gets the leading-dimension range
        ``[r * num_local_experts, (r + 1) * num_local_experts)`` of every
        tensor. Weights and scales are indexed as 3-D tensors.
        """
        s = rank * num_local_experts
        e = s + num_local_experts
        w1 = self.w1[s:e, :, :]
        w2 = self.w2[s:e, :, :]
        w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None
        w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None
        w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
        w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None

        return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)

    @staticmethod
    def make(config: Config) -> "WeightTensors":
        """Build test weights for ``config`` via ``make_test_weights``."""
        (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
            e=config.E,
            n=config.N,
            k=config.K,
            in_dtype=config.dtype,
            quant_dtype=config.quant_dtype,
            block_shape=config.quant_block_shape,
            # or config.is_per_out_ch_quant
            # NOTE(review): the per-act-token flag is what drives
            # per_out_ch_quant here - confirm this is intentional.
            per_out_ch_quant=config.is_per_act_token_quant,
        )
        return WeightTensors(
            w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
        )


@dataclass
class RankTensors:
    """Per-rank inputs for one MoE invocation."""

    hidden_states: torch.Tensor
    hidden_states_scale: torch.Tensor | None

    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    expert_map: torch.Tensor | None

    def describe(self):
        """Return a multi-line shape/dtype/device summary of every tensor."""
        named_tensors = (
            (self.hidden_states, "HS"),
            (self.hidden_states_scale, "HS_scale"),
            (self.topk_weights, "topk_weights"),
            (self.topk_ids, "topk_ids"),
            (self.expert_map, "expert_map"),
        )
        header = "== Rank Tensors: \n"
        return header + "".join(
            f" - {_describe_tensor(tensor, label)} \n"
            for tensor, label in named_tensors
        )

    @staticmethod
    def make_hidden_states(
        config: Config,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Return ``(hidden_states, hidden_states_scale)`` for ``config``.

        The activations are random ``(M, K)`` values on the current
        accelerator device. When a quant dtype is configured they are
        quantized and dequantized once before being returned. A scale is
        returned only in the per-tensor case; otherwise it is ``None``.
        """
        m, k, dtype = (config.M, config.K, config.dtype)
        device = torch.accelerator.current_device_index()
        a = torch.randn((m, k), device=device, dtype=dtype) / 15.0

        if config.quant_dtype is None:
            return a, None

        # We dequant and use that as hidden_states so the tests are stable.
        # Quantizing and dequantizing yield slightly different results
        # depending on the hardware. Here we quantize and dequantize
        # first - so further quantize and dequantize will yield the same
        # values.
        if config.is_per_tensor_act_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
            return a_q.float().mul(a_scales).to(dtype), a_scales

        if config.is_per_act_token_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return a_q.float().mul(a_scales).to(dtype), None

        assert config.quant_block_shape is not None
        block_k = config.quant_block_shape[1]
        a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
        # Apply one scale per block_k-wide chunk, then restore (m, k).
        dequantized = (
            a_q.float()
            .view((-1, block_k))
            .mul(a_scales.view(-1, 1))
            .view(m, k)
            .to(dtype)
        )
        return dequantized, None

    @staticmethod
    def make(config: Config, pgi: ProcessGroupInfo):
        """Create the hidden states, routing tensors and expert map.

        ``expert_map`` is only built when ``world_size > 1`` and the
        fused-experts type supports it. It holds ``-1`` for every expert
        except this rank's own range, which maps to local indices
        ``0 .. num_local_experts - 1``.
        """
        dtype = config.dtype
        topk, m = config.topk, config.M
        # Same device selection as make_hidden_states, so the router scores
        # end up next to the activations on any accelerator backend.
        device = torch.accelerator.current_device_index()
        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)

        num_local_experts, global_num_experts = (config.num_local_experts, config.E)
        score = torch.randn((m, global_num_experts), device=device, dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)

        # distribute topk_ids evenly
        for mi in range(m):
            topk_ids[mi] = torch.randperm(config.E)[:topk]
        topk_ids = topk_ids.to(device=device)

        expert_map = None
        if config.world_size > 1 and config.supports_expert_map():
            expert_map = torch.full(
                (global_num_experts,), fill_value=-1, dtype=torch.int32
            )
            s = pgi.rank * num_local_experts
            e = s + num_local_experts
            expert_map[s:e] = torch.arange(num_local_experts, dtype=torch.int32)
            expert_map = expert_map.to(device=device, dtype=torch.int32)

        return RankTensors(
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
        )


491
492
493
def reference_moe_impl(
    config: Config, weights: WeightTensors, rank_tensors: RankTensors
) -> torch.Tensor:
    """Compute the reference MoE output with ``torch_experts``.

    For ``quant_dtype == "nvfp4"`` the activations are quantized to fp4 and
    dequantized again, the weights are dequantized per expert, and
    ``torch_experts`` is then called without any quantization arguments.
    For every other quant dtype the tensors and quant settings are passed
    through unchanged. ``expert_map`` is always passed as ``None``.
    """
    if config.quant_dtype == "nvfp4":
        quant_blocksize = 16
        dtype = config.dtype

        w1_q = weights.w1
        w1_blockscale = weights.w1_scale
        w1_gs = weights.w1_gs

        w2_q = weights.w2
        w2_blockscale = weights.w2_scale
        w2_gs = weights.w2_gs

        a_global_scale = (
            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
            / torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
        ).to(torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        assert w1_blockscale.shape[1] % 128 == 0
        assert w1_blockscale.shape[2] % 4 == 0
        assert w2_blockscale.shape[1] % 128 == 0
        assert w2_blockscale.shape[2] % 4 == 0

        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
            rank_tensors.hidden_states, a_global_scale
        )

        a = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=dtype,
            device=a_fp4.device,
            block_size=quant_blocksize,
        )

        e = w1_q.shape[0]
        n = w1_q.shape[1] // 2
        k = w2_q.shape[1]

        # Allocate the dequantized weights next to the activations rather
        # than on a hard-coded device name.
        w1 = torch.zeros((e, 2 * n, k), device=a.device, dtype=dtype)
        w2 = torch.zeros((e, k, n), device=a.device, dtype=dtype)

        for idx in range(e):
            w1[idx] = dequantize_nvfp4_to_dtype(
                w1_q[idx],
                w1_blockscale[idx],
                w1_gs[idx],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2[idx] = dequantize_nvfp4_to_dtype(
                w2_q[idx],
                w2_blockscale[idx],
                w2_gs[idx],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        # Everything is dequantized now, so the reference runs unquantized.
        a_scale = None
        w1_scale = None
        w2_scale = None
        quant_dtype = None
        per_act_token_quant = False
        block_shape = None
    else:
        a = rank_tensors.hidden_states
        a_scale = rank_tensors.hidden_states_scale
        w1 = weights.w1
        w1_scale = weights.w1_scale
        w2 = weights.w2
        w2_scale = weights.w2_scale
        quant_dtype = config.quant_dtype
        per_act_token_quant = config.is_per_act_token_quant
        block_shape = config.quant_block_shape

    return torch_experts(
        a=a,
        w1=w1,
        w2=w2,
        topk_weight=rank_tensors.topk_weights,
        topk_ids=rank_tensors.topk_ids,
        global_num_experts=config.E,
        expert_map=None,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a_scale,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        apply_router_weights_on_input=config.topk == 1
        and config.supports_apply_weight_on_input(),
    )
592
593


594
def _make_gscale(num_experts: int) -> torch.Tensor:
    """Return a float32 vector of ``num_experts`` ones.

    The tensor is created on the current accelerator device.
    """
    device = torch.accelerator.current_device_index()
    return torch.ones((num_experts,), device=device, dtype=torch.float32)
600
601


602
603
604
def make_modular_kernel(
    config: Config,
    vllm_config: VllmConfig,
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEKernel:
    """Assemble a ``FusedMoEKernel`` for ``config``.

    Builds the parallel and MoE configs from the active process groups,
    creates the prepare-finalize object and the fused experts, and wraps
    both in a non-inplace ``mk.FusedMoEKernel``.
    """

    def next_power_of_2(x: int) -> int:
        # Smallest power of two >= x, with 0 mapped to 1. Integer
        # arithmetic avoids the float rounding of log2/ceil.
        if x == 0:
            return 1
        return 1 << (x - 1).bit_length()

    # make moe config
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        pcp_size_=get_pcp_group().world_size,
        dp_size_=get_dp_group().world_size,
        sp_size_=1,
        vllm_parallel_config=vllm_config.parallel_config,
    )

    moe = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
        hidden_dim=config.K,
        intermediate_size_per_partition=config.N,
        num_local_experts=config.num_local_experts,
        num_logical_experts=config.E,
        moe_parallel_config=moe_parallel_config,
        in_dtype=config.dtype,
        max_num_tokens=next_power_of_2(config.M),
        activation=MoEActivation.SILU,
        device=vllm_config.device_config.device,
        routing_method=RoutingMethodType.DeepSeekV3,
    )

    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe,
        quant_config=quant_config,
        allow_new_interface=True,
    )
    assert prepare_finalize is not None

    fused_experts = make_fused_experts(
        config.fused_experts_type,
        moe,
        quant_config,
        prepare_finalize.num_dispatchers(),
        config.N,
    )

    return mk.FusedMoEKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=fused_experts,
        inplace=False,
    )


def run_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> torch.Tensor:
    """Run the modular kernel for this rank and return its output.

    Slices ``weights`` down to the experts owned by ``pgi.rank``, builds the
    quant config and the kernel, and applies the kernel to a clone of the
    rank's hidden states inside a forward context. ``config.Ms`` and
    ``config.topks`` must already be single ints.
    """
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # Only nvfp4 gets activation global scales.
    gscale = (
        _make_gscale(config.num_local_experts)
        if config.quant_dtype == "nvfp4"
        else None
    )

    quant_config = FusedMoEQuantConfig.make(
        config.quant_dtype,
        w1_scale=rank_weights.w1_scale,
        w2_scale=rank_weights.w2_scale,
        a1_scale=rank_tensors.hidden_states_scale,
        g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
        g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
        a1_gscale=gscale,
        a2_gscale=gscale,
        block_shape=config.quant_block_shape,
        per_act_token_quant=config.is_per_act_token_quant,
        per_out_ch_quant=config.is_per_out_ch_quant,
    )

    # Named `kernel` so it does not shadow the module-level `mk` alias.
    kernel = make_modular_kernel(config, vllm_config, quant_config)

    # impls might update the tensor in place
    hidden_states = rank_tensors.hidden_states.clone()

    topk_ids = rank_tensors.topk_ids.to(kernel.prepare_finalize.topk_indices_dtype())

    kernel_kwargs = {
        "hidden_states": hidden_states,
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": topk_ids,
        "activation": MoEActivation.SILU,
        "expert_map": rank_tensors.expert_map,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1
        and config.supports_apply_weight_on_input(),
    }

    num_tokens = rank_tensors.hidden_states.shape[0]
    # Every rank reports the same token count; keep the tensor on the same
    # device as the activations instead of a hard-coded device name.
    num_tokens_across_dp = torch.tensor(
        [num_tokens] * config.world_size,
        device=hidden_states.device,
        dtype=torch.int,
    )

    with set_forward_context(
        None,
        vllm_config,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
    ):
        out = kernel.apply(**kernel_kwargs)

    return out