# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional, Union

import torch

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
                                                        TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx

from .parallel_utils import ProcessGroupInfo
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
                    make_quant_fp8_weights, per_token_cast_to_fp8)

if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)
if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)


def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
    """Format a one-line ``name : shape dtype device`` summary of ``t``.

    A missing tensor (``None``) is rendered as ``name : None``.
    """
    details = "None" if t is None else f"{t.shape} {t.dtype} {t.device}"
    return f"{name} : {details}"


def _optional_backend_types(*names: str) -> list[type]:
    """Return the classes among ``names`` that this module managed to import.

    The pplx and DeepEP prepare/finalize classes are imported only when
    ``has_pplx()`` / ``has_deep_ep()`` report the backend as available, so
    the bare names may not exist at module scope. Resolving them by name
    turns "backend not installed" into "no match" instead of a ``NameError``.
    """
    module_scope = globals()
    return [module_scope[name] for name in names if name in module_scope]


@dataclass
class Config:
    """One modular-kernel MoE test configuration.

    Bundles the problem sizes, the activation dtype, the optional
    quantization config and the prepare/finalize + fused-experts classes
    under test, and answers compatibility questions about that combination
    (see ``is_valid``).
    """
    # Row count(s) of the hidden states. ``M`` requires a single int.
    Ms: Union[list[int], int]
    # Hidden size (columns of the hidden states).
    K: int
    # Passed as ``n`` to the weight factories.
    N: int
    # Global number of experts.
    E: int
    # Experts selected per token. ``topk`` requires a single int.
    topks: Union[list[int], int]
    # Activation dtype.
    dtype: torch.dtype
    # ``None`` means unquantized.
    quant_config: Optional[FusedMoEQuantConfig]

    # Classes (not instances) of the components under test.
    prepare_finalize_type: type[mk.FusedMoEPrepareAndFinalize]
    fused_experts_type: type[mk.FusedMoEPermuteExpertsUnpermute]

    # Exported as VLLM_FUSED_MOE_CHUNK_SIZE when not None.
    fused_moe_chunk_size: Optional[int]
    # Number of ranks; also used as the data-parallel size.
    world_size: int

    torch_trace_dir_path: Optional[str] = None

    def describe(self) -> str:
        """Return a multi-line, human-readable summary of this config."""
        lines = [
            "== Config: ",
            f" world_size={self.world_size} ",
            f" PF={self.prepare_finalize_type.__name__} ",
            f" FE={self.fused_experts_type.__name__} ",
            f" E={self.E} ",
            f" Ms={self.Ms} ",
            f" N={self.N} ",
            f" K={self.K} ",
            f" topk={self.topks} ",
            f" dtype={self.dtype} ",
            f" fused_moe_chunk_size={self.fused_moe_chunk_size} ",
            " Quant: ",
        ]
        if self.quant_config is not None:
            lines += [
                f"     q_dtype={self.quant_dtype} ",
                f"     q_block_shape={self.quant_block_shape} ",
                f"     q_per_out_ch_quant={self.is_per_out_ch_quant} ",
                f"     q_per_act_token={self.is_per_act_token_quant} ",
            ]
        else:
            lines.append("     quant=None ")
        return "".join(f"{line}\n" for line in lines)

    @property
    def M(self) -> int:
        """The single row count; only valid once ``Ms`` is an int."""
        assert isinstance(self.Ms, int)
        return self.Ms

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        """Quantized dtype, or ``None`` when unquantized."""
        if self.quant_config is None:
            return None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        """Whether activations are quantized per token."""
        if self.quant_config is None:
            return False
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        """Whether activations are quantized per tensor.

        True for a quantized config that is neither per-token nor
        block-quantized.
        """
        if self.quant_config is None:
            return False
        return (not self.is_per_act_token_quant
                and self.quant_block_shape is None)

    @property
    def is_per_out_ch_quant(self) -> bool:
        """Whether weights are quantized per output channel."""
        if self.quant_config is None:
            return False
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> Optional[list[int]]:
        """Quantization block shape, or ``None`` if not block-quantized."""
        if self.quant_config is None:
            return None
        return self.quant_config.block_shape

    @property
    def topk(self) -> int:
        """The single top-k value; only valid once ``topks`` is an int."""
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def topk_ids_dtype(self) -> Optional[torch.dtype]:
        """dtype the prepare/finalize class expects for ``topk_ids``.

        ``None`` means no conversion is requested.
        """
        if self.needs_pplx():
            return torch.uint32
        if self.needs_deep_ep():
            return torch.int64
        return None

    @property
    def num_local_experts(self) -> int:
        """Experts per rank (``E`` floor-divided by ``world_size``)."""
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """
        Build the ``VllmConfig`` and environment-variable overrides needed
        to launch vLLM for this configuration.
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True

        env_dict = {
            "VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
            "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
        }
        if self.fused_moe_chunk_size is not None:
            env_dict.update(
                {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
        return vllm_config, env_dict

    def is_fp8_block_quantized(self) -> bool:
        return (self.quant_dtype == torch.float8_e4m3fn
                and self.quant_block_shape is not None)

    def is_batched_prepare_finalize(self) -> bool:
        return self.prepare_finalize_type in _optional_backend_types(
            "PplxPrepareAndFinalize", "DeepEPLLPrepareAndFinalize")

    def is_batched_fused_experts(self) -> bool:
        return self.fused_experts_type in [
            CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
            NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
        ]

    def is_standard_fused_experts(self) -> bool:
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]

    def is_fe_16bit_supported(self) -> bool:
        return self.fused_experts_type in [
            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
            NaiveBatchedExperts, TritonExperts
        ]

    def is_fe_fp8_supported(self) -> bool:
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            BatchedTritonOrDeepGemmExperts,
            CutlassExpertsFp8,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            NaiveBatchedExperts,
        ]

    def is_fe_block_fp8_supported(self) -> bool:
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonOrDeepGemmExperts,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            BatchedTritonExperts,
            NaiveBatchedExperts,
        ]

    def is_fe_supports_chunking(self) -> bool:
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]

    def needs_deep_gemm(self) -> bool:
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            DeepGemmExperts,
        ]

    def needs_pplx(self) -> bool:
        return self.prepare_finalize_type in _optional_backend_types(
            "PplxPrepareAndFinalize")

    def needs_deep_ep(self) -> bool:
        return self.prepare_finalize_type in _optional_backend_types(
            "DeepEPHTPrepareAndFinalize", "DeepEPLLPrepareAndFinalize")

    def all2all_backend(self) -> str:
        """Value for ``VLLM_ALL2ALL_BACKEND`` matching the PF class."""
        if self.needs_pplx():
            return "pplx"
        if self.prepare_finalize_type in _optional_backend_types(
                "DeepEPHTPrepareAndFinalize"):
            return "deepep_high_throughput"
        if self.prepare_finalize_type in _optional_backend_types(
                "DeepEPLLPrepareAndFinalize"):
            return "deepep_low_latency"
        return "naive"

    def needs_all2all(self) -> bool:
        return self.needs_pplx() or self.needs_deep_ep()

    def is_valid(self) -> bool:
        """Return True when this combination of settings can be run."""
        # Check prepare-finalize and fused-experts compatibility
        if self.is_batched_prepare_finalize():
            if not self.is_batched_fused_experts():
                return False
        else:
            if not self.is_standard_fused_experts():
                return False

        use_chunking = self.fused_moe_chunk_size is not None
        if use_chunking and not self.is_fe_supports_chunking():
            return False

        # Check quantization sanity
        if (int(self.is_per_act_token_quant) +
                int(self.is_per_tensor_act_quant) +
                int(self.quant_block_shape is not None)) > 1:
            # invalid quant config
            return False

        # check bf16 / fp16 support
        is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
        if is_16bit and not self.is_fe_16bit_supported():
            return False

        # Check fp8 support
        is_fp8 = self.quant_dtype == torch.float8_e4m3fn
        if is_fp8 and not self.is_fe_fp8_supported():
            return False

        # Check fp8 block quantization support
        is_block_quantized = self.quant_block_shape is not None
        if is_block_quantized and not is_fp8:
            return False
        if is_block_quantized and not self.is_fe_block_fp8_supported():
            return False

        # deep_gemm only works with block-quantized
        if self.needs_deep_gemm() and not is_block_quantized:
            return False

        # Check dependencies
        if self.needs_deep_ep() and not has_deep_ep():
            return False
        if self.needs_deep_gemm() and not has_deep_gemm():
            return False
        if self.needs_pplx() and not has_pplx():  # noqa: SIM103
            return False

        return True


@dataclass
class WeightTensors:
    """Expert weights ``w1`` / ``w2`` plus their optional scales."""
    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: Optional[torch.Tensor]
    w2_scale: Optional[torch.Tensor]

    def describe(self):
        """Return a multi-line summary of every tensor held here."""
        named_tensors = (
            ("w1", self.w1),
            ("w2", self.w2),
            ("w1_scale", self.w1_scale),
            ("w2_scale", self.w2_scale),
        )
        body = "".join(f' - {_describe_tensor(tensor, label)} \n'
                       for label, tensor in named_tensors)
        return "== Weight Tensors: \n" + body

    def to_current_device(self):
        """Move the weights (and fp8 scales, if any) to the current GPU."""
        device = torch.cuda.current_device()
        self.w1 = self.w1.to(device=device)
        self.w2 = self.w2.to(device=device)
        if self.w1.dtype != torch.float8_e4m3fn:
            # Unquantized weights carry no scales to move.
            return
        assert self.w1_scale is not None
        assert self.w2_scale is not None
        self.w1_scale = self.w1_scale.to(device=device)
        self.w2_scale = self.w2_scale.to(device=device)

    def slice_weights(self, rank: int,
                      num_local_experts: int) -> "WeightTensors":
        """Return the experts ``rank`` owns as a new ``WeightTensors``."""
        first = rank * num_local_experts
        local = slice(first, first + num_local_experts)
        local_w1 = self.w1[local, :, :]
        local_w2 = self.w2[local, :, :]
        local_w1_scale = None
        local_w2_scale = None
        if self.w1.dtype == torch.float8_e4m3fn:
            assert self.w1_scale is not None
            assert self.w2_scale is not None
            local_w1_scale = self.w1_scale[local, :, :]
            local_w2_scale = self.w2_scale[local, :, :]
        return WeightTensors(local_w1, local_w2, local_w1_scale,
                             local_w2_scale)

    @staticmethod
    def make(config: Config) -> "WeightTensors":
        """Create weights for ``config``: plain, fp8, or block-quant fp8."""
        dims = {"e": config.E, "n": config.N, "k": config.K}

        if config.quant_dtype is None:
            # just make normal dtype weights
            w1, w2 = make_non_quant_weights(dtype=config.dtype, **dims)
            return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)

        assert config.quant_dtype == torch.float8_e4m3fn
        if config.is_fp8_block_quantized():
            assert config.quant_block_shape is not None
            quantized = make_block_quant_fp8_weights(
                block_size=config.quant_block_shape, **dims)
        else:
            quantized = make_quant_fp8_weights(
                per_out_channel_quant=config.is_per_out_ch_quant, **dims)

        w1, w2, w1_scale, w2_scale = quantized
        return WeightTensors(w1=w1,
                             w2=w2,
                             w1_scale=w1_scale,
                             w2_scale=w2_scale)


@dataclass
class RankTensors:
    """Per-rank inputs for one MoE invocation."""
    hidden_states: torch.Tensor
    # Only set for per-tensor activation quantization; None otherwise.
    hidden_states_scale: Optional[torch.Tensor]

    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    # Global-expert-id -> local-expert-id map (-1 for experts owned by other
    # ranks); None when world_size == 1.
    expert_map: Optional[torch.Tensor]

    quant_config: Optional[FusedMoEQuantConfig]

    def describe(self):
        """Return a multi-line summary of every tensor held here."""
        s = ""
        s += "== Rank Tensors: \n"
        s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
        s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
        s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
        s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
        s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
        return s

    @staticmethod
    def make_hidden_states(
            config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Return ``(hidden_states, hidden_states_scale)`` for ``config``.

        The hidden states are random ``(M, K)`` values in ``config.dtype`` on
        the current CUDA device. The scale is returned only in the
        per-tensor quantization case; every other case returns ``None``.
        """
        m, k, dtype = (config.M, config.K, config.dtype)
        a = (torch.randn(
            (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)

        if config.quant_dtype is None:
            return a, None

        # We dequantize and use the result as hidden_states so the tests are
        # stable. Quantizing and dequantizing yields slightly different
        # results depending on the hardware. Here we quantize and dequantize
        # once up front, so that any further quantize/dequantize round trip
        # yields the same values.
        if config.is_per_tensor_act_quant:
            a_q, a_scales = ops.scaled_fp8_quant(
                a, use_per_token_if_dynamic=False)
            return a_q.float().mul(a_scales).to(dtype), a_scales

        if config.is_per_act_token_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a,
                                                 use_per_token_if_dynamic=True)
            return a_q.float().mul(a_scales).to(dtype), None

        assert config.quant_block_shape is not None
        block_k = config.quant_block_shape[1]
        a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
        return a_q.float().view(
            (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None

    @staticmethod
    def make(config: Config, pgi: ProcessGroupInfo):
        """Build the hidden states, routing tensors and expert map for the
        rank described by ``pgi``."""
        dtype = config.dtype
        topk, m = config.topk, config.M
        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
            config)

        num_local_experts, global_num_experts = (config.num_local_experts,
                                                 config.E)
        score = torch.randn((m, global_num_experts),
                            device="cuda",
                            dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
                                               False)
        topk_ids = topk_ids.to(config.topk_ids_dtype)

        # Distribute topk_ids evenly: overwrite each token's routed ids with
        # the first ``topk`` entries of a random permutation of all experts.
        for mi in range(m):
            topk_ids[mi] = torch.randperm(config.E)[:topk]
        topk_ids = topk_ids.to(device=torch.cuda.current_device())

        expert_map = None
        if config.world_size > 1:
            expert_map = torch.full((global_num_experts, ),
                                    fill_value=-1,
                                    dtype=torch.int32)
            # This rank owns the contiguous expert range [s, e); those map to
            # local ids 0..num_local_experts-1, everything else stays -1.
            s = pgi.rank * num_local_experts
            e = s + num_local_experts
            expert_map[s:e] = torch.arange(num_local_experts,
                                           dtype=torch.int32)
            expert_map = expert_map.to(device=torch.cuda.current_device(),
                                       dtype=torch.int32)

        return RankTensors(
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            quant_config=config.quant_config,
        )


def reference_moe_impl(config: Config, weights: WeightTensors,
                       rank_tensors: RankTensors) -> torch.Tensor:
    """Compute the reference MoE output with ``torch_experts``.

    Uses the full (unsliced) ``weights`` and no expert map, so the result
    covers all ``config.E`` experts.
    """
    reference_kwargs = {
        "a": rank_tensors.hidden_states,
        "w1": weights.w1,
        "w2": weights.w2,
        "topk_weight": rank_tensors.topk_weights,
        "topk_ids": rank_tensors.topk_ids,
        "global_num_experts": config.E,
        "expert_map": None,
        "w1_scale": weights.w1_scale,
        "w2_scale": weights.w2_scale,
        "a1_scale": rank_tensors.hidden_states_scale,
        "quant_dtype": config.quant_dtype,
        "per_act_token_quant": config.is_per_act_token_quant,
        "block_shape": config.quant_block_shape,
        # Router weights are applied on the input only for top-1 routing.
        "apply_router_weights_on_input": config.topk == 1,
    }
    return torch_experts(**reference_kwargs)


def make_fused_experts(
        config: Config, moe: FusedMoEConfig,
        num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Instantiate the fused-experts implementation selected by ``config``.

    Args:
        config: Test configuration; ``config.fused_experts_type`` selects the
            class and the quantization properties supply its arguments.
        moe: MoE config providing ``max_num_tokens``, ``num_experts``,
            ``num_local_experts`` and ``in_dtype``.
        num_dispatchers: Forwarded to the batched implementations and to
            ``CutlassExpertsFp8``.

    Returns:
        The constructed fused-experts object.

    Raises:
        ValueError: If ``config.fused_experts_type`` is not one of the
            implementations handled here.
    """

    use_fp8 = config.quant_dtype == torch.float8_e4m3fn
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8,
        "use_int8_w8a8": False,
        "use_int8_w8a16": False,
        "use_int4_w4a16": False,
        "block_shape": config.quant_block_shape,
        "per_act_token_quant": config.is_per_act_token_quant,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    if config.fused_experts_type == BatchedDeepGemmExperts:
        kwargs = batch_kwargs | {
            "block_shape": config.quant_block_shape,
            "per_act_token_quant": config.is_per_act_token_quant,
        }
        print(f"Making BatchedDeepGemmExperts {kwargs} ...")
        experts = BatchedDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making BatchedTritonExperts {kwargs} ...")
        experts = BatchedTritonExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == DeepGemmExperts:
        print("Making DeepGemmExperts () ...")
        experts = DeepGemmExperts()
    elif config.fused_experts_type == TritonExperts:
        kwargs = quant_kwargs
        print(f"Making TritonExperts {kwargs} ...")
        experts = TritonExperts(**kwargs)
    elif config.fused_experts_type == TritonOrDeepGemmExperts:
        kwargs = quant_kwargs | deepgemm_kwargs
        print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
        experts = TritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == NaiveBatchedExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making NaiveBatchedExperts {kwargs} ...")
        experts = NaiveBatchedExperts(**kwargs)
    elif config.fused_experts_type == CutlassExpertsFp8:
        # The batched format sizes the kernel by local experts; the standard
        # format sizes it by the global expert count.
        use_batched_format = config.is_batched_prepare_finalize()
        num_experts = (moe.num_local_experts
                       if use_batched_format else moe.num_experts)
        kwargs = {
            "max_experts_per_worker": num_experts,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": config.is_per_act_token_quant,
            "per_out_ch_quant": config.is_per_out_ch_quant,
            "block_shape": config.quant_block_shape,
            "num_dispatchers": num_dispatchers,
            "use_batched_format": use_batched_format
        }
        print(f"Making CutlassExpertsFp8 {kwargs} ...")
        experts = CutlassExpertsFp8(**kwargs)
    else:
        # Without this branch an unknown type would surface later as an
        # opaque UnboundLocalError on ``experts``.
        raise ValueError("Unsupported fused experts type: "
                         f"{config.fused_experts_type}")

    return experts


def make_modular_kernel(config: Config,
                        vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
    """Assemble the modular MoE kernel described by ``config``.

    Builds a ``FusedMoEConfig`` from ``config`` and the current parallel
    groups, creates the prepare/finalize object (the all2all variant when
    ``config.needs_all2all()``, the no-EP one otherwise) and the matching
    fused-experts object, and wraps both in a ``FusedMoEModularKernel``.
    """

    def next_power_of_2(x: int) -> int:
        # Integer arithmetic only: a float log2 is inexact for large values
        # and raises for negative ones.
        if x <= 1:
            return 1
        return 1 << (x - 1).bit_length()

    # make moe config
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        dp_size_=get_dp_group().world_size,
        vllm_parallel_config=vllm_config.parallel_config,
    )
    moe = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
        hidden_dim=config.K,
        num_local_experts=config.num_local_experts,
        moe_parallel_config=moe_parallel_config,
        in_dtype=config.dtype,
        quant_config=config.quant_config,
        max_num_tokens=next_power_of_2(config.M),
    )

    # make modular kernel
    if config.needs_all2all():
        prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
        assert prepare_finalize is not None
    else:
        prepare_finalize = MoEPrepareAndFinalizeNoEP()

    fused_experts = make_fused_experts(config, moe,
                                       prepare_finalize.num_dispatchers())

    modular_kernel = mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=fused_experts)

    return modular_kernel


def run_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> torch.Tensor:
    """Run the modular kernel for this rank and return its output.

    ``config`` must already be narrowed to a single ``Ms`` / ``topks`` value.
    The weights are sliced down to the experts owned by ``pgi.rank``.
    """
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    local_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    kernel = make_modular_kernel(config, vllm_config)

    # impls might update the tensor in place, so hand over a copy
    hidden_states_copy = rank_tensors.hidden_states.clone()

    return kernel.forward(
        hidden_states=hidden_states_copy,
        w1=local_weights.w1,
        w2=local_weights.w2,
        topk_weights=rank_tensors.topk_weights,
        topk_ids=rank_tensors.topk_ids,
        expert_map=rank_tensors.expert_map,
        w1_scale=local_weights.w1_scale,
        w2_scale=local_weights.w2_scale,
        a1_scale=rank_tensors.hidden_states_scale,
        global_num_experts=config.E,
        apply_router_weight_on_input=config.topk == 1,
    )