common.py 22.7 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
4
from typing import Any
5
6
7
8
9

import torch

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
10
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
11
12
13
14
15
from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
)
16
17
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
18
19
20
21
22
from vllm.distributed import (
    get_dp_group,
    get_pcp_group,
    get_tensor_model_parallel_world_size,
)
23
from vllm.forward_context import set_forward_context
24
from vllm.model_executor.layers.fused_moe import fused_topk
25
26
27
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
30
31
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
32
    RoutingMethodType,
33
)
34
35
36
37
38
39
40
from vllm.utils.import_utils import (
    has_aiter,
    has_deep_ep,
    has_deep_gemm,
    has_mori,
    has_pplx,
)
41

42
43
44
45
46
47
from .mk_objects import (
    TestMoEQuantConfig,
    expert_info,
    make_fused_experts,
    prepare_finalize_info,
)
48
49
50
from .parallel_utils import ProcessGroupInfo


51
def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
52
53
54
55
56
57
58
59
    if t is None:
        return f"{name} : None"
    else:
        return f"{name} : {t.shape} {t.dtype} {t.device}"


@dataclass
class Config:
    """A single modular-kernel MoE test configuration.

    Bundles the problem sizes, data types, quantization settings and the
    prepare/finalize + fused-experts implementation pair under test. The
    capability queries below are answered by looking the two implementation
    types up with ``prepare_finalize_info`` and ``expert_info``.
    """

    # Token count(s); must be a single int by the time ``M`` is read.
    Ms: list[int] | int
    # Hidden dimension (passed as ``hidden_dim`` when building the MoE config).
    K: int
    # Intermediate size per partition.
    N: int
    # Global number of experts.
    E: int
    # Top-k value(s); must be a single int by the time ``topk`` is read.
    topks: list[int] | int
    dtype: torch.dtype
    # ``None`` is replaced by an explicit "unquantized" config in __post_init__.
    quant_config: TestMoEQuantConfig | None

    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute

    fused_moe_chunk_size: int | None
    world_size: int

    torch_trace_dir_path: str | None = None

    def __post_init__(self):
        # Normalize "no quantization" to a concrete config object so the
        # quant_* properties below never have to special-case None.
        if self.quant_config is None:
            self.quant_config = TestMoEQuantConfig(None, False, False, None)

    def describe(self) -> str:
        """Return a multi-line, human-readable dump of this configuration."""
        lines = [
            "== Config:",
            f" world_size={self.world_size}",
            f" PF={self.prepare_finalize_type.__name__}",
            f" FE={self.fused_experts_type.__name__}",
            f" E={self.E}",
            f" Ms={self.Ms}",
            f" N={self.N}",
            f" K={self.K}",
            f" topk={self.topks}",
            f" dtype={self.dtype}",
            f" fused_moe_chunk_size={self.fused_moe_chunk_size}",
            " Quant:",
        ]
        if self.quant_config is not None:
            lines += [
                f"     q_dtype={self.quant_dtype}",
                f"     q_block_shape={self.quant_block_shape}",
                f"     q_per_out_ch_quant={self.is_per_out_ch_quant}",
                f"     q_per_act_token={self.is_per_act_token_quant}",
            ]
        else:
            # Only reachable if quant_config was reset to None after init.
            lines.append("     quant=None")
        return "".join(f"{line}\n" for line in lines)

    @property
    def M(self) -> int:
        """The token count; requires ``Ms`` to be a single int."""
        assert isinstance(self.Ms, int)
        return self.Ms

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        """Quantized dtype from the quant config (``None`` if unquantized)."""
        assert self.quant_config is not None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        """Whether the quant config requests per-activation-token quantization."""
        assert self.quant_config is not None
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        """True when neither per-token nor block quantization is configured."""
        return not self.is_per_act_token_quant and self.quant_block_shape is None

    @property
    def is_per_out_ch_quant(self) -> bool:
        """Whether the quant config requests per-output-channel quantization."""
        assert self.quant_config is not None
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> list[int] | None:
        """Block shape from the quant config (``None`` if not block quantized)."""
        assert self.quant_config is not None
        return self.quant_config.block_shape

    @property
    def topk(self) -> int:
        """The top-k value; requires ``topks`` to be a single int."""
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def num_local_experts(self) -> int:
        """Experts per rank: ``E // world_size``."""
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """Build the ``VllmConfig`` and environment overrides for a vllm launch.

        Returns:
            ``(vllm_config, env_dict)``. The config has data parallelism set to
            ``world_size``, expert parallelism enabled and the all2all backend
            of the prepare/finalize type. ``env_dict`` always carries
            ``VLLM_USE_DEEP_GEMM`` and, when a chunk size is configured,
            ``VLLM_FUSED_MOE_CHUNK_SIZE``.
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True
        vllm_config.parallel_config.all2all_backend = self.all2all_backend()

        env_dict = {"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm()))}
        if self.fused_moe_chunk_size is not None:
            env_dict["VLLM_FUSED_MOE_CHUNK_SIZE"] = str(self.fused_moe_chunk_size)

        return vllm_config, env_dict

    def is_fp8_block_quantized(self):
        """True for fp8 (e4m3fn) quantization with a block shape."""
        return (
            self.quant_dtype == torch.float8_e4m3fn
            and self.quant_block_shape is not None
        )

    def is_batched_prepare_finalize(self):
        """True if the prepare/finalize type uses the BatchedExperts format."""
        info = prepare_finalize_info(self.prepare_finalize_type)
        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format

    def is_batched_fused_experts(self):
        """True if the fused-experts type uses the BatchedExperts format."""
        info = expert_info(self.fused_experts_type)
        return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format

    def is_standard_fused_experts(self):
        """True if the fused-experts type uses the Standard format."""
        info = expert_info(self.fused_experts_type)
        return mk.FusedMoEActivationFormat.Standard == info.activation_format

    def fe_supported_types(self):
        """Dtypes supported by the fused-experts type."""
        return expert_info(self.fused_experts_type).supported_dtypes

    def pf_supported_types(self):
        """Dtypes supported by the prepare/finalize type."""
        return prepare_finalize_info(self.prepare_finalize_type).supported_dtypes

    def is_block_quant_supported(self):
        """Whether the fused-experts type supports blocked quantization."""
        return expert_info(self.fused_experts_type).blocked_quantization_support

    def is_fe_supports_chunking(self):
        """Whether the fused-experts type supports chunking."""
        return expert_info(self.fused_experts_type).supports_chunking

    def supports_expert_map(self):
        """Whether the fused-experts type supports an expert map."""
        return expert_info(self.fused_experts_type).supports_expert_map

    def supports_apply_weight_on_input(self):
        """Whether the prepare/finalize type can apply router weights on input."""
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.supports_apply_weight_on_input

    def needs_deep_gemm(self):
        """Whether the fused-experts type requires DeepGEMM."""
        return expert_info(self.fused_experts_type).needs_deep_gemm

    def needs_pplx(self):
        """Whether the prepare/finalize backend is ``"pplx"``."""
        return prepare_finalize_info(self.prepare_finalize_type).backend == "pplx"

    def needs_deep_ep(self):
        """Whether the prepare/finalize backend is one of the DeepEP backends."""
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend in ("deepep_high_throughput", "deepep_low_latency")

    def needs_aiter(self):
        """Whether the fused-experts type requires Aiter."""
        return expert_info(self.fused_experts_type).needs_aiter

    def needs_mori(self):
        """Whether the prepare/finalize backend is ``"mori"``."""
        return prepare_finalize_info(self.prepare_finalize_type).backend == "mori"

    def all2all_backend(self):
        """Backend name registered for the prepare/finalize type."""
        return prepare_finalize_info(self.prepare_finalize_type).backend

    def is_valid(self) -> tuple[bool, str | None]:
        """Check whether this configuration is runnable.

        Returns:
            ``(True, None)`` when valid, otherwise ``(False, reason)`` where
            ``reason`` describes the first failed check.
        """
        # Check prepare-finalize and fused-experts compatibility: both sides
        # must agree on the activation format.
        if self.is_batched_prepare_finalize():
            formats_match = self.is_batched_fused_experts()
        else:
            formats_match = self.is_standard_fused_experts()
        if not formats_match:
            return False, "Mismatched format."

        use_chunking = self.fused_moe_chunk_size is not None
        if use_chunking and not self.is_fe_supports_chunking():
            return False, "Chunking not supported."

        # Check quantization sanity: at most one activation-quant granularity
        # (per-token, per-tensor, block) may be selected.
        granularities = (
            self.is_per_act_token_quant,
            self.is_per_tensor_act_quant,
            self.quant_block_shape is not None,
        )
        if sum(int(g) for g in granularities) > 1:
            # invalid quant config
            return False, f"Bad quant_config {self.quant_config}."

        # Check type support: the quantized dtype when quantizing, otherwise
        # the plain dtype, must be supported by both implementations.
        if self.quant_dtype is None:
            kind, required = "type", self.dtype
        else:
            kind, required = "quant type", self.quant_dtype
        pf_types = self.pf_supported_types()
        fe_types = self.fe_supported_types()
        if required not in pf_types or required not in fe_types:
            return False, (
                f"Unsupported {kind} {required} not in {pf_types} and {fe_types}."
            )

        # Check block quantization support
        is_block_quantized = self.quant_block_shape is not None
        if is_block_quantized and self.quant_dtype is None:
            return False, "No block quantization support."

        if is_block_quantized and not self.is_block_quant_supported():
            return False, "Mismatched block quantization support."

        # deep_gemm only works with block-quantized
        if self.needs_deep_gemm() and not is_block_quantized:
            return False, "Needs DeepGEMM but not block quantized."

        # Check dependencies (turn into asserts?)
        dependencies = (
            (self.needs_deep_ep, has_deep_ep, "Needs DeepEP, but DeepEP not available."),
            (
                self.needs_deep_gemm,
                has_deep_gemm,
                "Needs DeepGEMM, but DeepGEMM not available.",
            ),
            (self.needs_pplx, has_pplx, "Needs PPLX, but PPLX not available."),
            (self.needs_aiter, has_aiter, "Needs Aiter, but Aiter not available."),
            (self.needs_mori, has_mori, "Needs MoRI, but MoRI not available."),
        )
        for is_needed, is_available, reason in dependencies:
            if is_needed() and not is_available():
                return False, reason

        return True, None
300
301
302
303
304
305


@dataclass
class WeightTensors:
    w1: torch.Tensor
    w2: torch.Tensor
306
307
308
309
    w1_scale: torch.Tensor | None
    w2_scale: torch.Tensor | None
    w1_gs: torch.Tensor | None = None
    w2_gs: torch.Tensor | None = None
310
311
312
313

    def describe(self):
        s = ""
        s += "== Weight Tensors: \n"
314
315
316
317
318
319
        s += f" - {_describe_tensor(self.w1, 'w1')} \n"
        s += f" - {_describe_tensor(self.w2, 'w2')} \n"
        s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n"
        s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n"
        s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n"
        s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n"
320
321
        return s

322
323
    def is_quantized(self) -> bool:
        # or w1_scale is not None?
324
325
326
327
328
        return (
            self.w1.dtype == torch.float8_e4m3fn
            or self.w1.dtype == torch.uint8
            or self.w1.dtype == torch.int8
        )
329

330
    def to_current_device(self):
331
332
333
        device = torch.cuda.current_device()
        self.w1 = self.w1.to(device=device)
        self.w2 = self.w2.to(device=device)
334

335
336
337
338
        if self.w1_scale is not None:
            self.w1_scale = self.w1_scale.to(device=device)
        if self.w2_scale is not None:
            self.w2_scale = self.w2_scale.to(device=device)
339

340
        if self.w1_gs is not None:
341
342
343
            self.w1_gs = self.w1_gs.to(device=device)
        if self.w2_gs is not None:
            self.w2_gs = self.w2_gs.to(device=device)
344

345
    def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
346
347
348
349
        s = rank * num_local_experts
        e = s + num_local_experts
        w1 = self.w1[s:e, :, :]
        w2 = self.w2[s:e, :, :]
350
351
        w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None
        w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None
352
353
        w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
        w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
354

355
        return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
356

357
358
359
    @staticmethod
    def make(config: Config) -> "WeightTensors":
        (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
360
361
362
            e=config.E,
            n=config.N,
            k=config.K,
363
364
365
            in_dtype=config.dtype,
            quant_dtype=config.quant_dtype,
            block_shape=config.quant_block_shape,
366
367
            # or config.is_per_out_ch_quant
            per_out_ch_quant=config.is_per_act_token_quant,
368
369
370
        )
        return WeightTensors(
            w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
371
372
373
374
375
376
        )


@dataclass
class RankTensors:
    """Per-rank inputs for one MoE invocation."""

    hidden_states: torch.Tensor
    hidden_states_scale: torch.Tensor | None

    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    # Global-expert -> local-expert index map (-1 for experts not on this
    # rank); None when not needed.
    expert_map: torch.Tensor | None

    def describe(self):
        """Return a multi-line summary (shape/dtype/device) of every tensor."""
        labelled = (
            (self.hidden_states, "HS"),
            (self.hidden_states_scale, "HS_scale"),
            (self.topk_weights, "topk_weights"),
            (self.topk_ids, "topk_ids"),
            (self.expert_map, "expert_map"),
        )
        rows = "".join(
            f" - {_describe_tensor(tensor, label)} \n" for tensor, label in labelled
        )
        return "== Rank Tensors: \n" + rows

    @staticmethod
    def make_hidden_states(
        config: Config,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Create random hidden states for ``config`` on the current CUDA device.

        Returns:
            ``(hidden_states, scale)``. ``hidden_states`` has shape
            ``(config.M, config.K)`` and dtype ``config.dtype``. ``scale`` is
            the activation scale for per-tensor quantization and ``None`` in
            every other case.
        """
        m, k, dtype = config.M, config.K, config.dtype
        a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0

        if config.quant_dtype is None:
            return a, None

        # We dequant and use that as hidden_states so the tests are stable.
        # Quantizing and dequantizing yield slightly different results
        # depending on the hardware. Here we quantize and dequantize first,
        # so that further quantize/dequantize round trips yield the same
        # values.
        if config.is_per_tensor_act_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
            return a_q.float().mul(a_scales).to(dtype), a_scales

        if config.is_per_act_token_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return a_q.float().mul(a_scales).to(dtype), None

        # Remaining case: block quantization along K.
        assert config.quant_block_shape is not None
        block_k = config.quant_block_shape[1]
        a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
        dequantized = (
            a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k)
        )
        return dequantized.to(dtype), None

    @staticmethod
    def make(config: Config, pgi: ProcessGroupInfo):
        """Create the hidden states, routing tensors and expert map for a rank."""
        dtype = config.dtype
        topk, m = config.topk, config.M
        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)

        num_local_experts, global_num_experts = config.num_local_experts, config.E
        score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)

        # distribute topk_ids evenly: overwrite each row with distinct,
        # uniformly drawn expert ids.
        for mi in range(m):
            topk_ids[mi] = torch.randperm(config.E)[:topk]
        topk_ids = topk_ids.to(device=torch.cuda.current_device())

        expert_map = None
        if config.world_size > 1 and config.supports_expert_map():
            expert_map = torch.full(
                (global_num_experts,), fill_value=-1, dtype=torch.int32
            )
            # This rank owns a contiguous range of global experts, numbered
            # 0..num_local_experts-1 locally.
            first = pgi.rank * num_local_experts
            expert_map[first : first + num_local_experts] = torch.arange(
                num_local_experts, dtype=torch.int32
            )
            expert_map = expert_map.to(
                device=torch.cuda.current_device(), dtype=torch.int32
            )

        return RankTensors(
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
        )


462
463
464
def _dequantize_nvfp4_inputs(
    config: Config, weights: WeightTensors, rank_tensors: RankTensors
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(a, w1, w2)`` dequantized from NVFP4 to ``config.dtype``.

    The activations are quantized to FP4 and immediately dequantized; the
    already-quantized expert weights are dequantized expert by expert.
    """
    quant_blocksize = 16
    dtype = config.dtype

    w1_q, w1_blockscale, w1_gs = weights.w1, weights.w1_scale, weights.w1_gs
    w2_q, w2_blockscale, w2_gs = weights.w2, weights.w2_scale, weights.w2_gs

    # NOTE(review): amax is taken over the signed values, not their absolute
    # values - confirm this matches the kernel under test.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
        / torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
    ).to(torch.float32)

    assert w1_gs is not None
    assert w2_gs is not None
    assert w1_blockscale is not None
    assert w2_blockscale is not None

    # Layout requirements on the block scales.
    assert w1_blockscale.shape[1] % 128 == 0
    assert w1_blockscale.shape[2] % 4 == 0
    assert w2_blockscale.shape[1] % 128 == 0
    assert w2_blockscale.shape[2] % 4 == 0

    a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
        rank_tensors.hidden_states, a_global_scale
    )
    a = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        a_global_scale,
        dtype=dtype,
        device=a_fp4.device,
        block_size=quant_blocksize,
    )

    e = w1_q.shape[0]
    n = w1_q.shape[1] // 2
    k = w2_q.shape[1]

    w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
    w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)

    for idx in range(e):
        w1[idx] = dequantize_nvfp4_to_dtype(
            w1_q[idx],
            w1_blockscale[idx],
            w1_gs[idx],
            dtype=dtype,
            device=w1_q.device,
            block_size=quant_blocksize,
        )
        w2[idx] = dequantize_nvfp4_to_dtype(
            w2_q[idx],
            w2_blockscale[idx],
            w2_gs[idx],
            dtype=dtype,
            device=w2_q.device,
            block_size=quant_blocksize,
        )

    return a, w1, w2


def reference_moe_impl(
    config: Config, weights: WeightTensors, rank_tensors: RankTensors
) -> torch.Tensor:
    """Compute the reference MoE output with ``torch_experts``.

    For ``"nvfp4"`` the inputs are dequantized up front and the reference runs
    unquantized; for every other quant dtype the (possibly quantized) weights
    and scales are handed to ``torch_experts`` as-is. The reference always
    runs over all ``config.E`` experts (no expert map).
    """
    if config.quant_dtype == "nvfp4":
        a, w1, w2 = _dequantize_nvfp4_inputs(config, weights, rank_tensors)
        # Everything is already dequantized: run the reference unquantized.
        a_scale = None
        w1_scale = None
        w2_scale = None
        quant_dtype = None
        per_act_token_quant = False
        block_shape = None
    else:
        a = rank_tensors.hidden_states
        a_scale = rank_tensors.hidden_states_scale
        w1 = weights.w1
        w1_scale = weights.w1_scale
        w2 = weights.w2
        w2_scale = weights.w2_scale
        quant_dtype = config.quant_dtype
        per_act_token_quant = config.is_per_act_token_quant
        block_shape = config.quant_block_shape

    return torch_experts(
        a=a,
        w1=w1,
        w2=w2,
        topk_weight=rank_tensors.topk_weights,
        topk_ids=rank_tensors.topk_ids,
        global_num_experts=config.E,
        expert_map=None,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a_scale,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        apply_router_weights_on_input=config.topk == 1
        and config.supports_apply_weight_on_input(),
    )
563
564


565
def _make_gscale(num_experts: int) -> torch.Tensor:
    """Return a float32 vector of ones, one per expert, on the current CUDA device."""
    return torch.ones(
        (num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
    )
569
570


571
572
573
def _next_power_of_2(x: int) -> int:
    """Return the smallest power of two that is >= ``x`` (1 for ``x == 0``).

    Uses integer arithmetic only, so the result is exact for arbitrarily
    large ``x`` (a float ``log2`` can round down just above a power of two).

    Raises:
        ValueError: if ``x`` is negative.
    """
    if x < 0:
        raise ValueError(f"expected a non-negative integer, got {x}")
    if x == 0:
        return 1
    return 1 << (x - 1).bit_length()


def make_modular_kernel(
    config: Config,
    vllm_config: VllmConfig,
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
    """Assemble the modular kernel described by ``config``.

    Builds the MoE parallel and layer configs from the current distributed
    state, creates the prepare/finalize object and the fused-experts object,
    and combines them into a non-inplace ``FusedMoEModularKernel``.
    """
    # make moe config
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        pcp_size_=get_pcp_group().world_size,
        dp_size_=get_dp_group().world_size,
        sp_size_=1,
        vllm_parallel_config=vllm_config.parallel_config,
    )

    moe = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
        hidden_dim=config.K,
        intermediate_size_per_partition=config.N,
        num_local_experts=config.num_local_experts,
        num_logical_experts=config.E,
        moe_parallel_config=moe_parallel_config,
        in_dtype=config.dtype,
        # The token budget is rounded up to a power of two.
        max_num_tokens=_next_power_of_2(config.M),
        activation="silu",
        device=vllm_config.device_config.device,
        routing_method=RoutingMethodType.DeepSeekV3,
    )

    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe,
        quant_config=quant_config,
        allow_new_interface=True,
    )
    assert prepare_finalize is not None

    fused_experts = make_fused_experts(
        config.fused_experts_type,
        moe,
        quant_config,
        prepare_finalize.num_dispatchers(),
        config.N,
    )

    return mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=fused_experts,
        inplace=False,
    )


def run_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> torch.Tensor:
    """Run the modular kernel for this rank and return its output.

    Slices the global weights down to this rank's experts, builds the quant
    config and the modular kernel, and calls the kernel's ``forward`` inside
    a forward context.
    """
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # Activation global scales are only supplied for nvfp4.
    gscale = (
        _make_gscale(config.num_local_experts)
        if config.quant_dtype == "nvfp4"
        else None
    )

    quant_config = FusedMoEQuantConfig.make(
        config.quant_dtype,
        w1_scale=rank_weights.w1_scale,
        w2_scale=rank_weights.w2_scale,
        a1_scale=rank_tensors.hidden_states_scale,
        # The alphas are the reciprocals of the weight global scales.
        g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
        g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
        a1_gscale=gscale,
        a2_gscale=gscale,
        block_shape=config.quant_block_shape,
        per_act_token_quant=config.is_per_act_token_quant,
        per_out_ch_quant=config.is_per_out_ch_quant,
    )

    # Deliberately not named ``mk``: that would shadow the module alias.
    modular_kernel = make_modular_kernel(config, vllm_config, quant_config)

    # impls might update the tensor in place
    hidden_states = rank_tensors.hidden_states.clone()

    topk_ids = rank_tensors.topk_ids.to(
        modular_kernel.prepare_finalize.topk_indices_dtype()
    )

    mk_kwargs = {
        "hidden_states": hidden_states,
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": topk_ids,
        "expert_map": rank_tensors.expert_map,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1
        and config.supports_apply_weight_on_input(),
    }

    # Every DP rank is reported as processing the same number of tokens.
    num_tokens = rank_tensors.hidden_states.shape[0]
    num_tokens_across_dp = torch.tensor(
        [num_tokens] * config.world_size, device="cuda", dtype=torch.int
    )

    with set_forward_context(
        None,
        vllm_config,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
    ):
        out = modular_kernel.forward(**mk_kwargs)

    return out