mxfp4.py 45.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable
4
from enum import Enum
5
from typing import Optional
6
7
8
9
10

import torch
from torch.nn.parameter import Parameter

from vllm import envs
11
from vllm.attention.layer import Attention
12
from vllm.config import get_current_vllm_config
13
from vllm.logger import init_logger
14
15
16
17
18
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
19
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
20
from vllm.model_executor.layers.fused_moe.config import (
21
    FusedMoEQuantConfig,
22
    mxfp4_mxfp8_moe_quant_config,
23
    mxfp4_w4a16_moe_quant_config,
24
    ocp_mx_moe_quant_config,
25
)
26
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
27
    BatchedMarlinExperts,
28
29
30
    MarlinExperts,
    fused_marlin_moe,
)
31
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
32
33
    OAITritonExperts,
)
34
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
35
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
36
37
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
38
39
40
    QuantizationConfig,
    QuantizeMethodBase,
)
41
42
43
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
44
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
45
46
    prepare_moe_fp4_layer_for_marlin,
)
47
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
48
49
    _can_support_mxfp4,
    _swizzle_mxfp4,
50
    get_padding_alignment,
51
52
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
53
from vllm.model_executor.utils import set_weight_attrs
54
from vllm.platforms import current_platform
55
from vllm.scalar_type import scalar_types
56
from vllm.utils.flashinfer import has_flashinfer
57
from vllm.utils.import_utils import has_triton_kernels
Cyrus Leung's avatar
Cyrus Leung committed
58
from vllm.utils.math_utils import round_up
59
from vllm.utils.torch_utils import is_torch_equal_or_newer
60

61
62
63
logger = init_logger(__name__)


64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
    """MXFP4 MoE kernel backends that ``get_mxfp4_backend`` can select.

    ``NONE`` is returned when no compatible backend is available; the MoE
    method asserts that it never ends up with ``NONE``.
    """

    NONE = 0

    # FlashInfer Backend
    # SM100_* members are only selected on device capability 100 and the
    # SM90_* member on capability 90 (see get_mxfp4_backend).
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4

    # Marlin Backend
    MARLIN = 5

    # Triton Backend
    TRITON = 6


81
82
83
84
85
86
87
88
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    """
    Choose an MXFP4 MoE backend that is known to work together with LoRA.

    Only a subset of the MXFP4 backends support LoRA. On CUDA the Marlin
    backend is used; on every other platform no LoRA-capable backend is
    known, so ``Mxfp4Backend.NONE`` is returned.
    """
    if current_platform.is_cuda():
        logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
        return Mxfp4Backend.MARLIN
    return Mxfp4Backend.NONE
91
92
93


def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    """Select the MXFP4 MoE kernel backend for the current platform.

    With ``with_lora_support`` set, the choice is delegated to
    ``get_mxfp4_backend_with_lora``. Otherwise, on CUDA the first match wins:

    * SM90 + FlashInfer + ``VLLM_USE_FLASHINFER_MOE_MXFP4_BF16``
      -> ``SM90_FI_MXFP4_BF16``
    * SM100 + FlashInfer + ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS``
      -> ``SM100_FI_MXFP4_MXFP8_CUTLASS``
    * SM100 + FlashInfer + ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8``
      -> ``SM100_FI_MXFP4_MXFP8_TRTLLM``
    * SM100 + FlashInfer -> ``SM100_FI_MXFP4_BF16``
    * otherwise (SM90/SM100 without FlashInfer additionally logs a warning):
      ``MARLIN`` if ``VLLM_MXFP4_USE_MARLIN`` is set or the Triton kernels are
      unusable (not installed, torch older than 2.8.0, or device capability
      outside [9.0, 11.0)); else ``TRITON``.

    XPU always gets ``MARLIN``; ROCm gets ``TRITON`` when the Triton kernels
    are installed. Any other platform gets ``NONE``.

    Args:
        with_lora_support: Restrict the choice to LoRA-capable backends.

    Returns:
        The selected backend, or ``Mxfp4Backend.NONE`` if none applies.
    """
    # Backend Selection

    if with_lora_support:
        return get_mxfp4_backend_with_lora()

    if current_platform.is_cuda():
        if (
            current_platform.is_device_capability(90)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        ):
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16
        elif (
            current_platform.is_device_capability(100)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
        ):
            logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
        elif (
            current_platform.is_device_capability(100)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        ):
            # Log the choice like every other branch so backend selection
            # stays visible in the logs.
            logger.info_once("Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
        elif current_platform.is_device_capability(100) and has_flashinfer():
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy."
            )
            return Mxfp4Backend.SM100_FI_MXFP4_BF16
        elif (
            current_platform.is_device_capability(100)
            or current_platform.is_device_capability(90)
        ) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results."
            )

        # If FlashInfer is not available, try either Marlin or Triton
        triton_kernels_supported = (
            has_triton_kernels()
            and is_torch_equal_or_newer("2.8.0")
            # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
            # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
            # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
            and (9, 0) <= current_platform.get_device_capability() < (11, 0)
        )
        if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
            logger.info_once("Using Marlin backend")
            return Mxfp4Backend.MARLIN
        else:
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON
    elif current_platform.is_xpu():
        logger.info_once("Using ipex marlin backend on XPU")
        return Mxfp4Backend.MARLIN
    elif current_platform.is_rocm() and has_triton_kernels():
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    return Mxfp4Backend.NONE
161
162
163


class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 checkpoints.

    Only fused-MoE layers receive an MXFP4 quantize method. Linear layers
    always fall back to ``UnquantizedLinearMethod``, and attention and other
    layers get no quantize method (``None``).
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Layer names handed to is_layer_skipped(); matching linear layers are
        # treated as deliberately unquantized.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # NOTE(review): ``config`` is not inspected, so instances built here
        # always have ``ignored_layers=None``.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``.

        * ``LinearBase``: always ``UnquantizedLinearMethod``. A debug message is
          logged unless the layer is listed in ``ignored_layers``.
        * ``FusedMoE``: ``IpexMxfp4MoEMethod`` on XPU, otherwise an
          ``Mxfp4MoEMethod`` whose ``marlin_input_dtype`` is resolved from
          ``prefix``.
        * ``Attention``: logs that MXFP4 attention is unsupported and returns
          ``None``.
        * Anything else: ``None``.
        """
        if isinstance(layer, LinearBase):
            listed_as_ignored = self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if not listed_as_ignored:
                # TODO: Add support for MXFP4 Linear Method.
                # MXFP4 LinearMethod is available in AMD-Quark, refer to that
                # implementation if you are interested in enabling MXFP4 here.
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            if current_platform.is_xpu():
                return IpexMxfp4MoEMethod(layer.moe_config)
            moe_method = Mxfp4MoEMethod(layer.moe_config)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method

        if isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )
        return None


class Mxfp4MoEMethod(FusedMoEMethodBase):
    def __init__(self, moe: FusedMoEConfig):
        """Select an MXFP4 MoE backend for ``moe`` and set up per-layer state.

        Args:
            moe: MoE layer configuration. ``moe.is_lora_enabled`` restricts the
                backend choice to LoRA-capable kernels.

        Raises:
            AssertionError: If no compatible MXFP4 MoE backend is available.
        """
        super().__init__(moe)
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)

        # Filled in by Mxfp4Config.get_quant_method after construction; used
        # when preparing weights for the Marlin backend.
        self.marlin_input_dtype = None
        self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # Adjacent string literals are concatenated verbatim, so each piece
        # needs its own trailing space to keep the words separated.
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
            "Please check your environment and try again."
        )
        # Permutation-index cache keyed by tensor shape; passed to FlashInfer's
        # get_w2_permute_indices_with_cache when shuffling weights.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
241

242
243
244
245
246
247
248
249
250
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate zero-filled MXFP4 expert weights, block scales and biases.

        ``hidden_size`` and ``intermediate_size_per_partition`` are first padded
        to the alignment required by ``self.mxfp4_backend``. The padded sizes
        are stored on ``self`` (and, for Marlin, also on ``layer``).

        Registered on ``layer``, in order:
          * ``w13_weight`` / ``w13_weight_scale`` / ``w13_bias`` - fused
            gate_up_proj (column parallel), leading dim ``2 * intermediate``.
          * ``w2_weight`` / ``w2_weight_scale`` / ``w2_bias`` - down_proj
            (row parallel).
        Packed weights use ``uint8`` with the reduction dim halved (two 4-bit
        values per byte). Scales use ``uint8`` with one entry per 32-element
        block. Biases are ``bfloat16``. Every parameter gets
        ``extra_weight_attrs``.
        """
        self.num_experts = num_experts
        packed_dtype = torch.uint8
        block_scale_dtype = torch.uint8
        # FIXME (zyongye): ship after torch and safetensors support mxfp4 -
        # switch to torch.float4_e2m1fn_x2 / torch.float8_e8m0fnu once both
        # dtypes are available.
        block = 32  # elements sharing one MXFP4 scale

        backend = self.mxfp4_backend
        padded_inter = intermediate_size_per_partition
        if backend == Mxfp4Backend.MARLIN:
            # Marlin MoE needs n % 256 == 0 and k % 128 == 0 per GEMM:
            #   gate_up_proj: n = 2 * padded_inter, k = hidden_size
            #   down_proj:    n = hidden_size,      k = padded_inter
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(
                hidden_size, 128 if current_platform.is_xpu() else 256
            )
            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = padded_inter
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # Pad both dims to a multiple of 256 (itself a multiple of
            # 2 * block) so non-uniformly sharded tensors fit and can be
            # swizzled; the extra padding also helps performance.
            padded_inter = round_up(intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 128)
        elif current_platform.is_rocm():
            alignment = get_padding_alignment()
            padded_inter = round_up(intermediate_size_per_partition, alignment)
            hidden_size = round_up(hidden_size, alignment)
        else:
            # Only the intermediate dim is padded on this path.
            padded_inter = round_up(intermediate_size_per_partition, 64)

        self.intermediate_size = padded_inter
        self.hidden_size = hidden_size

        # (parameter name, shape, dtype), registered in this order.
        param_specs = (
            # Fused gate_up_proj (column parallel)
            (
                "w13_weight",
                (num_experts, 2 * padded_inter, hidden_size // 2),
                packed_dtype,
            ),
            (
                "w13_weight_scale",
                (num_experts, 2 * padded_inter, hidden_size // block),
                block_scale_dtype,
            ),
            ("w13_bias", (num_experts, 2 * padded_inter), torch.bfloat16),
            # down_proj (row parallel)
            (
                "w2_weight",
                (num_experts, hidden_size, padded_inter // 2),
                packed_dtype,
            ),
            (
                "w2_weight_scale",
                (num_experts, hidden_size, padded_inter // block),
                block_scale_dtype,
            ),
            ("w2_bias", (num_experts, hidden_size), torch.bfloat16),
        )
        for param_name, shape, dtype in param_specs:
            param = torch.nn.Parameter(
                torch.zeros(*shape, dtype=dtype),
                requires_grad=False,
            )
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
394
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
395
            prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
396
397
398
399
400
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
401
            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
402
403
404
405
406
407
408
409
410
411
412
413
414

            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
415
416
            sf_block_size = 32  # mxfp4 block size

417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
451
452
453
454
455
456
457
458

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

co63oc's avatar
co63oc committed
459
            # Swap w1 and w3 as the definition of
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
492
                # w13 weight shuffling
493
                permute_indices = get_w2_permute_indices_with_cache(
494
495
496
497
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
498
499
500
501
502
                gemm1_weights_mxfp4_shuffled.append(
                    w13_weight[i]
                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                    .contiguous()
                )
503
                # w13 scale shuffling
504
                permute_sf_indices = get_w2_permute_indices_with_cache(
505
506
507
508
509
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
510
                gemm1_scales_mxfp4_shuffled.append(
511
512
513
514
515
516
517
518
                    nvfp4_block_scale_interleave(
                        w13_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w13_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
519
                # w13 bias shuffling
520
                permute_bias_indices = get_w2_permute_indices_with_cache(
521
522
523
524
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
525
526
527
528
529
530
                gemm1_bias_shuffled.append(
                    w13_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                    .contiguous()
                )
531
                # w2 weight shuffling
532
                permute_indices = get_w2_permute_indices_with_cache(
533
534
535
536
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
537
538
539
540
541
                gemm2_weights_mxfp4_shuffled.append(
                    w2_weight[i]
                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                    .contiguous()
                )
542
                # w2 scale shuffling
543
                permute_sf_indices = get_w2_permute_indices_with_cache(
544
545
546
547
548
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
549
                gemm2_scales_mxfp4_shuffled.append(
550
551
552
553
554
555
556
557
                    nvfp4_block_scale_interleave(
                        w2_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w2_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
558
                # w2 bias shuffling
559
                permute_indices = get_w2_permute_indices_with_cache(
560
561
562
563
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
564
565
566
567
568
569
                gemm2_bias_shuffled.append(
                    w2_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                    .contiguous()
                )
570
571

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
572
573
574
575
576
577
578
579
580
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
581
582

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
583
584
585
586
587
588
589
590
591
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
592
593

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
594
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
595
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
596
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
597
598
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
                requires_grad=False,
            )
            layer.w2_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
621
622
623
624

            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
684
685
                    w13_scale_swapped.view(torch.uint8)
                ).reshape(orig_shape)
686
687
688
689

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
690
691
692
693
694
695
696
697
698
699
700
                    w2_s.view(torch.uint8)
                ).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
                layer.w13_weight_scale = Parameter(
                    w13_scale_interleaved, requires_grad=False
                )
                layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
                layer.w2_weight_scale = Parameter(
                    w2_scale_interleaved, requires_grad=False
                )
701
702
703
704
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
705
706
707
                    w_interleaved = w.reshape(
                        w_shape[0], w_shape[1], (w_shape[2] // 4), 4
                    )
708
709
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
710
711
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                    )
712
713
                    return w_interleaved

714
715
                w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
716
717
718

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
719
720
721
722
723
724
725
726
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)

                layer.w13_weight = torch.nn.Parameter(
                    torch.cat([w3_w, w1_w], dim=1), requires_grad=False
                )
                layer.w13_bias = torch.nn.Parameter(
                    w13_bias_swapped, requires_grad=False
                )
727
                layer.w13_weight_scale = torch.nn.Parameter(
728
729
                    w31_scales_interleaved, requires_grad=False
                )
730
                layer.w2_weight_scale = torch.nn.Parameter(
731
732
                    w2_scales_interleaved, requires_grad=False
                )
733
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
734
735
736
737
738
739
740
741
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)

742
743
744
745
746
            # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
            # (stored in self.fused_experts) to determine if the MoE has a
            # batched activation format. As self.fused_experts is not
            # initialized at this point, we resort to checking the MoE config
            # directly.
747
            is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
748
            if is_batched_moe:
749
750
751
752
753
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8

            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
754
755
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
756
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
757
758
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )
759
760

            self.w13_precision_config = PrecisionConfig(
761
762
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
763
            self.w2_precision_config = PrecisionConfig(
764
765
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )
766

767
768
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
769
770
771
772
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
773
774
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quantization config for the active MXFP4 backend.

        * ``TRITON``: W4A16 config whose scales are the ``PrecisionConfig``
          objects stored on ``self`` (``self.w13_precision_config`` /
          ``self.w2_precision_config``), not the raw scale tensors.
        * ``MARLIN`` and ``SM100_FI_MXFP4_BF16``: W4A16 config built from the
          layer's weight-scale tensors.
        * ``SM100_FI_MXFP4_MXFP8_TRTLLM`` and ``SM100_FI_MXFP4_MXFP8_CUTLASS``:
          MXFP4-weight / MXFP8-activation config.
        * Any other backend: generic OCP MX config with ``quant_dtype="mxfp4"``.

        Args:
            layer: The MoE layer holding ``w13_bias``, ``w2_bias`` and, for all
                non-Triton backends, ``w13_weight_scale`` / ``w2_weight_scale``.

        Returns:
            The quantization config for ``layer``.
        """
        backend = self.mxfp4_backend

        if backend == Mxfp4Backend.TRITON:
            # The Triton kernels take the PrecisionConfig wrappers instead of
            # the plain scale tensors.
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=self.w13_precision_config,
                w2_scale=self.w2_precision_config,
            )

        # Every other backend consumes the layer's bias and scale tensors as-is.
        tensor_kwargs = dict(
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )

        if backend in (Mxfp4Backend.MARLIN, Mxfp4Backend.SM100_FI_MXFP4_BF16):
            return mxfp4_w4a16_moe_quant_config(**tensor_kwargs)
        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        ):
            return mxfp4_mxfp8_moe_quant_config(**tensor_kwargs)
        return ocp_mx_moe_quant_config(quant_dtype="mxfp4", **tensor_kwargs)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the modular-kernel experts implementation for this backend.

        Args:
            prepare_finalize: The prepare/finalize object. Its
                ``activation_format`` decides between batched and standard
                experts implementations.
            layer: The MoE layer. Its ``gemm1_alpha``, ``gemm1_beta`` and
                ``gemm1_clamp_limit`` are forwarded to ``TrtLlmGenExperts``.

        Returns:
            * ``BatchedMarlinExperts`` for the batched format (Marlin only).
            * ``TrtLlmGenExperts`` for the SM100 TRTLLM / BF16 backends.
            * ``MarlinExperts`` for Marlin.
            * ``OAITritonExperts`` for Triton.

        Raises:
            NotImplementedError: if the backend has no experts implementation
                for the requested activation format.
        """
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # Marlin is the only MXFP4 backend with a batched-experts kernel.
            if self.mxfp4_backend != Mxfp4Backend.MARLIN:
                raise NotImplementedError(
                    f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                    "EP batched experts format"
                )
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            assert self.moe_quant_config is not None
            return BatchedMarlinExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )

        assert self.moe_quant_config is not None
        if self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # B200 code-path
            kwargs = {
                "gemm1_alpha": layer.gemm1_alpha,
                "gemm1_beta": layer.gemm1_beta,
                "gemm1_clamp_limit": layer.gemm1_clamp_limit,
                # TODO(bnell): part of quant_config
                "max_capture_size": self.max_capture_size,
            }
            return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            return MarlinExperts(self.moe_quant_config)
        if self.mxfp4_backend == Mxfp4Backend.TRITON:
            return OAITritonExperts(self.moe_quant_config)
        raise NotImplementedError(
            f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
        )

    @property
    def allow_inplace(self) -> bool:
        """Report that this MoE method permits in-place execution (always True)."""
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MXFP4 fused-MoE forward pass on the configured backend.

        Dispatch on ``self.mxfp4_backend``:

        * ``MARLIN``: route with ``layer.select_experts`` and call
          ``fused_marlin_moe``.
        * ``SM100_FI_MXFP4_MXFP8_TRTLLM`` / ``SM100_FI_MXFP4_BF16``: call
          FlashInfer ``trtllm_fp4_block_scale_moe`` on the raw router logits.
          The MXFP8 variant quantizes ``x`` to MXFP8 first. The BF16 variant
          requires ``x`` to already be bfloat16.
        * ``SM100_FI_MXFP4_MXFP8_CUTLASS`` / ``SM90_FI_MXFP4_BF16``: route with
          ``layer.select_experts`` and call ``flashinfer_cutlass_fused_moe``,
          which writes into a freshly allocated bfloat16 output.
        * ``TRITON``: call ``triton_kernel_moe_forward`` on the router logits.

        ``routed_scaling_factor`` is accepted for interface compatibility but
        is not read here.

        Returns:
            The MoE output tensor.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
            AssertionError: if a non-Marlin backend is given a routing
                configuration that ``_can_support_mxfp4`` rejects.
            ValueError: if ``self.mxfp4_backend`` is not handled.
        """
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        backend = self.mxfp4_backend

        if backend == Mxfp4Backend.MARLIN:
            topk_weights, topk_ids, _ = layer.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_bias,
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=None,
                global_scale2=None,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                activation=activation,
                expert_map=expert_map,
                input_dtype=self.marlin_input_dtype,
            )

        # The remaining backends only support a restricted routing setup.
        assert _can_support_mxfp4(
            use_grouped_topk,
            topk_group,
            num_expert_group,
            expert_map,
            custom_routing_function,
            e_score_correction_bias,
            apply_router_weight_on_input,
            scoring_func,
            activation,
            expert_load_view,
            logical_to_physical_map,
            logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            from flashinfer import trtllm_fp4_block_scale_moe

            if backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                # Activations are passed through unquantized.
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
                # SM100_FI_MXFP4_MXFP8_TRTLLM: quantize activations to MXFP8.
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                global_num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                self.num_experts,  # local num experts
                None,
                None,
                1 if renormalize else 0,  # routing_method_type, renormalize
                True,  # do finalize
                tune_max_num_tokens=max(self.max_capture_size, 1),
            )[0]
            return trtllm_gen_output
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

            topk_weights, topk_ids, _ = layer.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )

            # Backend-specific preparation
            if backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, True, 32)

                # All-ones per-expert input scale, shared by fc1 and fc2.
                fake_input_scale = torch.ones(self.num_experts, device=x.device)
                quant_scales = [
                    layer.w13_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                    layer.w2_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                ]

                fi_input = x_quant
                extra_kwargs = dict(
                    use_mxfp8_act_scaling=True,
                    input_sf=x_scale,
                    fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
                    fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
                )
            else:
                # SM90_FI_MXFP4_BF16: BF16 activations with W4 group scaling.
                assert x.dtype == torch.bfloat16

                quant_scales = [
                    layer.w13_weight_scale,
                    layer.w2_weight_scale,
                ]

                fi_input = x
                extra_kwargs = dict(
                    use_w4_group_scaling=True,
                    fc1_expert_weights=layer.w13_weight,
                    fc2_expert_weights=layer.w2_weight,
                )

            # The kernel writes its result into this buffer; the return value
            # of the call is discarded.
            output = torch.empty_like(x, dtype=torch.bfloat16)
            _ = flashinfer_cutlass_fused_moe(
                input=fi_input,
                token_selected_experts=topk_ids.to(torch.int).contiguous(),
                token_final_scales=topk_weights,
                output_dtype=torch.bfloat16,
                output=output,
                quant_scales=quant_scales,
                fc1_expert_biases=layer.w13_bias,
                fc2_expert_biases=layer.w2_bias,
                swiglu_alpha=layer.gemm1_alpha,
                swiglu_beta=layer.gemm1_beta,
                swiglu_limit=layer.gemm1_clamp_limit,
                tp_size=self.moe.tp_size,
                tp_rank=self.moe.tp_rank,
                ep_size=self.moe.ep_size,
                ep_rank=self.moe.ep_rank,
                tune_max_num_tokens=max(self.max_capture_size, 1),
                **extra_kwargs,
            )

            return output
        elif backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")


class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 MoE method that executes experts through IPEX ``GatedMLPMOE``.

    Weight creation is delegated to :class:`Mxfp4MoEMethod`. After loading,
    the weights are wrapped in an ``ipex.llm.modules.GatedMLPMOE`` module
    stored on ``layer.ipex_fusion``, which :meth:`apply` then invokes.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        # Read in process_weights_after_loading for the EP expert offset.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights via the parent class and record ``hidden_size``.

        The unpadded hidden size is kept so that :meth:`apply` can pad the
        input and slice the output back to it.
        """
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Wrap the loaded weights in an IPEX ``GatedMLPMOE`` module.

        Reinterprets the weight storage as ``int32`` and stores the resulting
        module on ``layer.ipex_fusion``.
        """
        import intel_extension_for_pytorch as ipex

        # Reinterpret the weight storage as int32 (a view, not a copy).
        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
        # Offset passed as experts_start_id; presumably the global id of this
        # rank's first local expert under EP.
        ep_rank_start = self.moe_config.ep_rank * self.moe_config.num_local_experts
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            is_mxfp4=True,
            experts_start_id=ep_rank_start,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MoE forward pass through ``layer.ipex_fusion``.

        The input's last dimension is zero-padded to a multiple of 128. The
        output is sliced back to ``self.original_hidden_size``.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set (same as the parent
                MXFP4 method; this path records no expert load).
            AssertionError: if ``activation`` is not ``"swigluoai"``.
        """
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")
        assert activation == "swigluoai", (
            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
        )
        # Pad the hidden dim up to a multiple of 128. NOTE(review): presumably
        # this matches the padded size the IPEX weights expect; confirm.
        hidden_size_pad = round_up(self.original_hidden_size, 128)
        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
        hidden_states = layer.ipex_fusion(
            x_pad,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            activation="swiglu_oai",
        )
        # Drop the padding columns added above.
        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
        return hidden_states