mxfp4.py 28.8 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import envs
9

10
from vllm.config import get_current_vllm_config
11
from vllm.logger import init_logger
12
13
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                  FusedMoEMethodBase)
14
15
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
16
17
18
19
20
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
22
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    prepare_moe_fp4_layer_for_marlin)
23
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
24
    _can_support_mxfp4, _swizzle_mxfp4)
25
26
27
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.utils import set_weight_attrs
28
from vllm.platforms import current_platform
29
30
31
from vllm.scalar_type import scalar_types
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                        next_power_of_2, round_up)
32
from vllm.utils.flashinfer import has_flashinfer
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
logger = init_logger(__name__)


def _should_use_flashinfer_mxfp4_bf16():
    """Decide whether the FlashInfer MXFP4 backend with BF16 input is used.

    An explicitly set ``VLLM_USE_FLASHINFER_MOE_MXFP4_BF16`` always wins and
    its value is returned as-is. Otherwise the backend is switched on by
    default only when all of the following hold:

    * the current device reports capability 100 (Blackwell, per the log
      message below),
    * FlashInfer is available, and
    * ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8`` has not been set at all.
    """
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
        # The user made an explicit choice; do not second-guess it.
        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16

    if not current_platform.is_device_capability(100):
        return False
    if not has_flashinfer():
        return False
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"):
        # Any explicit MXFP8 setting suppresses the implicit BF16 default.
        return False

    logger.info_once(
        "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
        "For faster performance, consider setting "
        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
        "though this may impact accuracy.")
    return True


def _should_use_flashinfer_mxfp4_mxfp8():
    """Return the value of ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8``.

    Unlike the BF16 variant, no platform-based default is applied here: the
    result is whatever the env setting evaluates to.
    """
    mxfp8_setting = envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    return mxfp8_setting


def should_use_flashinfer_mxfp4():
    """Return a truthy value when either FlashInfer MXFP4 variant is selected.

    The MXFP8 variant is consulted first; the BF16 check (which may emit a
    one-time info log) only runs when the MXFP8 variant is not selected.
    """
    mxfp8_choice = _should_use_flashinfer_mxfp4_mxfp8()
    if mxfp8_choice:
        return mxfp8_choice
    return _should_use_flashinfer_mxfp4_bf16()
64
65
66
67
68
69
70
71
72
73
74
75
76
77


class Mxfp4Config(QuantizationConfig):
    """Quantization config registered under the name ``"mxfp4"``.

    Only fused-MoE layers get a quantized method from this config. Linear
    layers must be listed in ``ignored_layers`` (they then run unquantized),
    and attention layers are rejected outright.
    """

    def __init__(self, ignored_layers: Optional[list[str]] = None):
        """Store the optional list of layer names that stay unquantized."""
        super().__init__()
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        """Build a default instance; the contents of ``config`` are unused."""
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability for this method (80)."""
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the supported activation dtypes (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config files to look for (none)."""
        return []

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns:
            ``UnquantizedLinearMethod`` for a linear layer matched by
            ``ignored_layers``, ``Mxfp4MoEMethod`` for a ``FusedMoE`` layer,
            and ``None`` for any other layer type.

        Raises:
            NotImplementedError: for a linear layer that is not skipped, or
                for an attention layer.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            # A linear layer is only acceptable when it is explicitly
            # excluded from quantization.
            if not self.ignored_layers or not is_layer_skipped(
                    prefix=prefix,
                    ignored_layers=self.ignored_layers,
                    fused_mapping=self.packed_modules_mapping):
                raise NotImplementedError(
                    "Mxfp4 linear layer is not implemented")
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)

        if isinstance(layer, Attention):
            raise NotImplementedError(
                "Mxfp4 attention layer is not implemented")

        return None


class Mxfp4MoEMethod(FusedMoEMethodBase):

    def __init__(self, moe: FusedMoEConfig):
        """Set up backend selection state for the MXFP4 fused-MoE method.

        Args:
            moe: The fused-MoE configuration; forwarded to the base class
                and also kept on ``self.moe``.

        Attributes set here:
            topk_indices_dtype: Left as ``None``.
            use_marlin: Result of ``_should_use_marlin()``.
            max_capture_size: ``compilation_config.max_capture_size`` of the
                current vLLM config.
        """
        super().__init__(moe)
        self.topk_indices_dtype = None
        self.moe = moe
        self.use_marlin = self._should_use_marlin()
        self.max_capture_size = get_current_vllm_config(
        ).compilation_config.max_capture_size

        # Emitted once only; the original repeated this exact block twice,
        # which was redundant because warning_once already de-duplicates.
        if current_platform.is_device_capability(100) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results.")

133
134
135
136
    def _should_use_marlin(self):
        """Decide whether the Marlin kernel backs this MoE method.

        ``VLLM_MXFP4_USE_MARLIN`` takes precedence whenever it is not
        ``None``. Otherwise Marlin is only considered on CUDA devices that
        are not capability 100, and is chosen there when any of these holds:

        * the device is below capability 90 (the original comment notes the
          marlin kernel performs better on ampere),
        * ``triton_kernels`` is unavailable, or
        * torch is older than 2.8.0.
        """
        forced = envs.VLLM_MXFP4_USE_MARLIN
        if forced is not None:
            return forced

        if (not current_platform.is_cuda()
                or current_platform.is_device_capability(100)):
            return False

        return (not current_platform.has_device_capability(90)
                or not has_triton_kernels()
                or not is_torch_equal_or_newer("2.8.0"))
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register the zero-filled MXFP4 expert parameters.

        The intermediate size (and, for some backends, the hidden size) is
        first rounded up to the alignment the selected backend needs. The
        padded sizes are remembered on ``self.intermediate_size`` and
        ``self.hidden_size``.

        With ``E = num_experts``, ``H`` the (possibly padded) hidden size and
        ``I`` the padded intermediate size, the registered parameters are:

        * ``w13_weight``        uint8     ``(E, 2 * I, H // 2)``
        * ``w13_weight_scale``  uint8     ``(E, 2 * I, H // 32)``
        * ``w13_bias``          bfloat16  ``(E, 2 * I)``
        * ``w2_weight``         uint8     ``(E, H, I // 2)``
        * ``w2_weight_scale``   uint8     ``(E, H, I // 32)``
        * ``w2_bias``           bfloat16  ``(E, H)``

        Every parameter has ``requires_grad=False`` and receives
        ``extra_weight_attrs`` through ``set_weight_attrs``.

        Args:
            layer: Module on which the parameters are registered.
            num_experts: Number of experts; also stored on the method.
            hidden_size: Hidden size before padding.
            intermediate_size_per_partition: Intermediate size before
                padding.
            params_dtype: Only recorded on ``layer`` in the Marlin path.
            **extra_weight_attrs: Attributes attached to each parameter.
        """
        self.num_experts = num_experts

        # FIXME (zyongye): once torch and safetensors support mxfp4, switch
        # to torch.float4_e2m1fn_x2 for the weights and torch.float8_e8m0fnu
        # for the scales (when both attributes exist on torch).
        packed_dtype = torch.uint8
        scale_dtype = torch.uint8
        bias_dtype = torch.bfloat16

        mxfp4_block = 32

        if self.use_marlin:
            # The moe marlin kernel requires, for each linear,
            # n % 256 == 0 and k % 128 == 0.
            #   gate_up_proj: n = 2 * padded intermediate, k = hidden_size
            #   down_proj:    n = hidden_size, k = padded intermediate
            padded_intermediate = round_up(intermediate_size_per_partition,
                                           128)
            hidden_size = round_up(hidden_size, 256)

            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = padded_intermediate
        elif should_use_flashinfer_mxfp4():
            # Pad both dimensions to multiples of 256: room for non-uniformly
            # sharded tensors and swizzling, plus extra padding for
            # performance.
            padded_intermediate = round_up(intermediate_size_per_partition,
                                           256)
            hidden_size = round_up(hidden_size, 256)
        elif current_platform.is_rocm():
            padded_intermediate = round_up(intermediate_size_per_partition,
                                           128)
        else:
            padded_intermediate = round_up(intermediate_size_per_partition,
                                           64)

        self.intermediate_size = padded_intermediate
        self.hidden_size = hidden_size

        def _register_zeros(name, shape, dtype):
            # Allocate, register on the layer, then attach the loader attrs.
            param = torch.nn.Parameter(
                torch.zeros(*shape, dtype=dtype),
                requires_grad=False,
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        gate_up_rows = 2 * padded_intermediate

        # Fused gate_up_proj (column parallel)
        _register_zeros("w13_weight",
                        (num_experts, gate_up_rows, hidden_size // 2),
                        packed_dtype)
        _register_zeros(
            "w13_weight_scale",
            (num_experts, gate_up_rows, hidden_size // mxfp4_block),
            scale_dtype)
        _register_zeros("w13_bias", (num_experts, gate_up_rows), bias_dtype)

        # down_proj (row parallel)
        _register_zeros("w2_weight",
                        (num_experts, hidden_size, padded_intermediate // 2),
                        packed_dtype)
        _register_zeros(
            "w2_weight_scale",
            (num_experts, hidden_size, padded_intermediate // mxfp4_block),
            scale_dtype)
        _register_zeros("w2_bias", (num_experts, hidden_size), bias_dtype)

    def process_weights_after_loading(self, layer):
        """Rearrange the loaded expert tensors for the selected backend.

        * Marlin: delegates to ``prepare_moe_fp4_layer_for_marlin``.
        * FlashInfer: creates the per-expert ``gemm1_alpha`` /
          ``gemm1_beta`` / ``gemm1_clamp_limit`` parameters, checks the
          parameter shapes, swaps adjacent row pairs of the w13 tensors,
          shuffles every expert's weights, scales and biases, and replaces
          the layer parameters with the results (scales viewed as
          ``float8_e4m3fn``, biases as float32).
        * Otherwise (triton kernels): converts the biases to float32,
          swizzles weights and scales via ``_swizzle_mxfp4``, keeps the
          results on ``self`` and drops the original weights from the layer.
        """
        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            return

        if should_use_flashinfer_mxfp4():
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            num_experts = self.num_experts
            inter = self.intermediate_size
            hidden = self.hidden_size

            def _per_expert_constant(value):
                # One float32 entry per expert, moved to the CUDA device.
                filled = torch.tensor([value] * num_experts,
                                      dtype=torch.float32)
                return Parameter(filled.cuda(), requires_grad=False)

            layer.gemm1_alpha = _per_expert_constant(1.702)
            layer.gemm1_beta = _per_expert_constant(1.0)
            layer.gemm1_clamp_limit = _per_expert_constant(7.0)
            sf_block_size = 32  # mxfp4 block size

            # Sanity-check the layouts produced by create_weights.
            assert (tuple(layer.w13_weight.shape) == (num_experts, inter * 2,
                                                      hidden // 2))
            assert (tuple(layer.w13_weight_scale.shape) == (
                num_experts, inter * 2, hidden // sf_block_size))
            assert (tuple(layer.w2_weight.shape) == (num_experts, hidden,
                                                     inter // 2))
            # The expert dimension of w2_weight_scale is not checked, only
            # its rank and trailing two dimensions.
            assert (layer.w2_weight_scale.dim() == 3
                    and tuple(layer.w2_weight_scale.shape[1:]) == (
                        hidden, inter // sf_block_size))
            assert tuple(layer.w13_bias.shape) == (num_experts, inter * 2)
            assert tuple(layer.w2_bias.shape) == (num_experts, hidden)

            w13_scale_raw = layer.w13_weight_scale.data
            w2_scale_raw = layer.w2_weight_scale.data
            w13_packed = layer.w13_weight.data
            w2_packed = layer.w2_weight.data
            w13_bias_f32 = layer.w13_bias.data.to(torch.float32)
            w2_bias_f32 = layer.w2_bias.data.to(torch.float32)

            # Swap w1 and w3, because the definition of swiglu is different
            # in trtllm-gen.
            def _swap_adjacent_pairs(t, axis=-1):
                full_shape = t.shape
                if axis < 0:
                    axis += len(full_shape)

                # Expose the pairs as an extra axis of length 2, reverse
                # that axis, then fold back to the original shape.
                paired_shape = list(full_shape)
                paired_shape[axis] = full_shape[axis] // 2
                paired_shape.insert(axis + 1, 2)
                swapped = t.reshape(*paired_shape).flip(axis + 1)
                return swapped.reshape(*full_shape)

            w13_scale_raw = _swap_adjacent_pairs(w13_scale_raw, -2)
            w13_packed = _swap_adjacent_pairs(w13_packed, -2)
            w13_bias_f32 = _swap_adjacent_pairs(w13_bias_f32, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            w13_weights_out = []
            w13_scales_out = []
            w13_biases_out = []
            w2_weights_out = []
            w2_scales_out = []
            w2_biases_out = []
            tile_m = 128  # FIXME: this depends on the kernel internals
            for expert in range(num_experts):
                w13_weights_out.append(
                    shuffle_matrix_a(w13_packed[expert].view(torch.uint8),
                                     tile_m))
                w13_scales_out.append(
                    shuffle_matrix_sf_a(
                        w13_scale_raw[expert].view(torch.uint8), tile_m))
                w13_biases_out.append(
                    shuffle_matrix_a(
                        w13_bias_f32[expert].clone().reshape(-1, 1), tile_m))

                w2_weights_out.append(
                    shuffle_matrix_a(w2_packed[expert].view(torch.uint8),
                                     tile_m))
                w2_scales_out.append(
                    shuffle_matrix_sf_a(
                        w2_scale_raw[expert].view(torch.uint8), tile_m))
                w2_biases_out.append(
                    shuffle_matrix_a(
                        w2_bias_f32[expert].clone().reshape(-1, 1), tile_m))

            w13_weight_final = torch.stack(w13_weights_out)
            w13_scale_final = torch.stack(w13_scales_out)
            w13_scale_final = w13_scale_final.reshape(
                num_experts, 2 * inter, hidden // sf_block_size)
            w13_scale_final = w13_scale_final.view(torch.float8_e4m3fn)

            w2_weight_final = torch.stack(w2_weights_out)
            w2_scale_final = torch.stack(w2_scales_out)
            w2_scale_final = w2_scale_final.reshape(
                num_experts, hidden, inter // sf_block_size)
            w2_scale_final = w2_scale_final.view(torch.float8_e4m3fn)

            layer.w13_weight = Parameter(w13_weight_final,
                                         requires_grad=False)
            layer.w13_weight_scale = Parameter(w13_scale_final,
                                               requires_grad=False)
            layer.w2_weight = Parameter(w2_weight_final, requires_grad=False)
            layer.w2_weight_scale = Parameter(w2_scale_final,
                                              requires_grad=False)
            layer.w13_bias = Parameter(
                torch.stack(w13_biases_out).reshape(num_experts, -1),
                requires_grad=False)
            layer.w2_bias = Parameter(
                torch.stack(w2_biases_out).reshape(num_experts, -1),
                requires_grad=False)
            return

        # Triton-kernels path.
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        bias13_f32 = layer.w13_bias.to(torch.float32)
        bias2_f32 = layer.w2_bias.to(torch.float32)
        layer.w13_bias = Parameter(bias13_f32, requires_grad=False)
        layer.w2_bias = Parameter(bias2_f32, requires_grad=False)

        # FIXME warp need to be adjusted based on batch size
        # only apply to  batched mode
        num_warps = 8
        if self.moe.use_ep and envs.VLLM_MOE_DP_CHUNK_SIZE <= 512:
            num_warps = 4

        w13_swizzled, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps)
        w2_swizzled, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps)

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))

        self.w13_weight_triton_tensor = w13_swizzled
        self.w2_weight_triton_tensor = w2_swizzled

        # need to delete the original weights to save memory on single GPU
        del layer.w13_weight
        del layer.w2_weight
        layer.w13_weight = None
        layer.w2_weight = None
        torch.cuda.empty_cache()
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
        """Pick the per-CTA token tile size for the FlashInfer MoE kernel.

        The expected number of tokens per expert under a perfectly even
        routing is ``(x.shape[0] * top_k) // self.num_experts``. That figure
        is scaled by an imbalance factor, rounded up to the next power of
        two and finally clamped to the 8-64 range, which the original
        comment gives as the range supported by the kernel.

        The imbalance factor is
        ``max_real_num_tokens_per_expert / perfect_num_tokens_per_expert``:
        1.0 means a perfect distribution, values above 1.0 mean some experts
        receive more tokens than that, and values below 1.0 do not make
        sense.
        """
        imbalance_factor = 1.3

        even_share = (x.shape[0] * top_k) // self.num_experts
        skewed_share = int(even_share * imbalance_factor)
        tile_tokens_dim = next_power_of_2(skewed_share)

        # Clamp into [8, 64].
        if tile_tokens_dim < 8:
            tile_tokens_dim = 8
        if tile_tokens_dim > 64:
            tile_tokens_dim = 64

        return tile_tokens_dim

457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the modular-kernel experts implementation for this layer.

        Only the non-batched activation format combined with a FlashInfer
        MXFP4 backend is supported; it yields a ``TrtLlmGenExperts`` built
        from the layer's gemm1 constants, biases and the max capture size.

        Raises:
            NotImplementedError: for the batched-experts activation format,
                or when no FlashInfer MXFP4 backend is selected.
        """
        batched_format = mk.FusedMoEActivationFormat.BatchedExperts
        if prepare_finalize.activation_format == batched_format:
            raise NotImplementedError(
                "Mxfp4 does not support batched experts format for EP")

        if not should_use_flashinfer_mxfp4():
            # Use matmul_ogs from triton_kernels here!
            raise NotImplementedError(
                "Mxfp4 does not support non-batched experts format for EP")

        # B200 code-path
        return TrtLlmGenExperts(
            moe,
            gemm1_alpha=layer.gemm1_alpha,
            gemm1_beta=layer.gemm1_beta,
            gemm1_clamp_limit=layer.gemm1_clamp_limit,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            max_capture_size=self.max_capture_size,
        )

    def _route_and_experts(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            router_logits: torch.Tensor,
            top_k: int,
            renormalize: bool,
            use_grouped_topk: bool = False,
            topk_group: Optional[int] = None,
            num_expert_group: Optional[int] = None,
            global_num_experts: int = -1,
            expert_map: Optional[torch.Tensor] = None,
            custom_routing_function: Optional[Callable] = None,
            scoring_func: str = "softmax",
            e_score_correction_bias: Optional[torch.Tensor] = None,
            apply_router_weight_on_input: bool = False,
            activation: str = "silu",
            enable_eplb: bool = False,
            expert_load_view: Optional[torch.Tensor] = None,
            logical_to_physical_map: Optional[torch.Tensor] = None,
            logical_replica_count: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Route tokens, then run them through the modular fused-experts kernel.

        Expert selection is done by ``FusedMoE.select_experts``; the chosen
        weights and ids are handed to ``self.fused_experts`` together with
        the layer's w13/w2 weights and scales. The kernel is invoked with
        ``inplace=True``.

        Requires ``self.fused_experts`` to be a ``mk.FusedMoEModularKernel``
        (checked with an assert).
        """
        experts_kernel = self.fused_experts
        assert isinstance(experts_kernel, mk.FusedMoEModularKernel)

        routing = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count)
        selected_weights, selected_ids = routing

        output = experts_kernel(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=selected_weights,
            topk_ids=selected_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
        return output

542
543
544
545
546
547
548
549
550
551
552
553
554
555
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MXFP4 MoE forward pass on the selected backend.

        Dispatch order:

        1. Marlin (``self.use_marlin``): route with
           ``FusedMoE.select_experts`` and call ``fused_marlin_moe``.
        2. Modular kernel (``self.fused_experts`` is set): delegate to
           ``_route_and_experts``.
        3. After asserting ``_can_support_mxfp4`` for the given routing
           options: FlashInfer's ``trtllm_fp4_block_scale_moe`` when a
           FlashInfer MXFP4 backend is selected, otherwise
           ``triton_kernel_moe_forward``.

        Raises:
            NotImplementedError: when ``enable_eplb`` is true.
        """
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.use_marlin:
            marlin_weights, marlin_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                routed_scaling_factor=routed_scaling_factor,
                e_score_correction_bias=e_score_correction_bias)

            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_bias,
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                marlin_weights,
                marlin_ids,
                global_scale1=None,
                global_scale2=None,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                activation=activation,
                expert_map=expert_map)

        if self.fused_experts is not None:
            # routed_scaling_factor is not forwarded on this path.
            return self._route_and_experts(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=top_k,
                renormalize=renormalize,
                use_grouped_topk=use_grouped_topk,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
                apply_router_weight_on_input=apply_router_weight_on_input,
                activation=activation,
                enable_eplb=enable_eplb,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

        config_supported = _can_support_mxfp4(
            use_grouped_topk, topk_group, num_expert_group, expert_map,
            custom_routing_function, e_score_correction_bias,
            apply_router_weight_on_input, scoring_func, activation,
            expert_load_view, logical_to_physical_map, logical_replica_count)
        assert config_supported, (
            "MXFP4 are not supported with this configuration.")

        if not should_use_flashinfer_mxfp4():
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward)
            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=self.w13_weight_triton_tensor,
                w2=self.w2_weight_triton_tensor,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_precision=self.w13_precision_config,
                w2_precision=self.w2_precision_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe

        if _should_use_flashinfer_mxfp4_bf16():
            # BF16 path: activations go to the kernel unquantized.
            assert x.dtype == torch.bfloat16
            x_quant, x_scale = x, None
        else:
            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
            scale_shape = (*x.shape[:-1], -1)
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*scale_shape)

        # The kernel is called positionally; keep this order intact.
        kernel_outputs = trtllm_fp4_block_scale_moe(
            router_logits.to(torch.bfloat16),
            None,  # routing_bias
            x_quant,
            x_scale,
            layer.w13_weight,  # packed fp4 weights, uint8 (e2m1 x 2)
            layer.w13_weight_scale,  # block scales, viewed as float8_e4m3fn
            layer.w13_bias,  # fp32 per expert per channel
            layer.gemm1_alpha,  # fp32 per expert
            layer.gemm1_beta,  # fp32 per expert
            layer.gemm1_clamp_limit,  # fp32 per expert
            layer.w2_weight,  # packed fp4 weights, uint8 (e2m1 x 2)
            layer.w2_weight_scale,  # block scales, viewed as float8_e4m3fn
            layer.w2_bias,  # fp32 per expert per channel
            None,  # output1_scale_scalar
            None,  # output1_scale_gate_scalar
            None,  # output2_scale_scalar
            global_num_experts,
            top_k,
            None,  # n_group
            None,  # topk_group
            self.intermediate_size,  # padded to multiple of 256
            layer.ep_rank * layer.local_num_experts,  # local_expert_offset
            self.num_experts,  # local num experts
            None,  # NOTE(review): unlabeled in the original; confirm meaning
            self._get_tile_tokens_dim(x, top_k),
            1 if renormalize else 0,  # routing_method_type, renormalize
            True,  # do finalize
            tune_max_num_tokens=self.max_capture_size,
        )
        return kernel_outputs[0]