mxfp4.py 46.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable
4
from enum import Enum
5
from typing import Optional
6
7
8
9
10

import torch
from torch.nn.parameter import Parameter

from vllm import envs
11
from vllm.attention.layer import Attention
12
from vllm.config import get_current_vllm_config
13
from vllm.logger import init_logger
14
15
16
17
18
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
19
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
20
from vllm.model_executor.layers.fused_moe.config import (
21
    FusedMoEQuantConfig,
22
    mxfp4_mxfp8_moe_quant_config,
23
    mxfp4_w4a16_moe_quant_config,
24
    ocp_mx_moe_quant_config,
25
)
26
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
27
    BatchedMarlinExperts,
28
29
30
    MarlinExperts,
    fused_marlin_moe,
)
31
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
32
    OAITritonExperts,
33
    UnfusedOAITritonExperts,
34
)
35
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
36
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
37
38
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
39
40
41
    QuantizationConfig,
    QuantizeMethodBase,
)
42
43
44
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
45
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
46
47
    prepare_moe_fp4_layer_for_marlin,
)
48
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
49
50
    _can_support_mxfp4,
    _swizzle_mxfp4,
51
    get_padding_alignment,
52
53
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
54
from vllm.model_executor.utils import set_weight_attrs
55
from vllm.platforms import current_platform
56
from vllm.scalar_type import scalar_types
57
from vllm.utils.flashinfer import has_flashinfer
58
from vllm.utils.import_utils import has_triton_kernels
Cyrus Leung's avatar
Cyrus Leung committed
59
from vllm.utils.math_utils import round_up
60
from vllm.utils.torch_utils import is_torch_equal_or_newer
61

62
63
64
logger = init_logger(__name__)


65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
    """MXFP4 MoE kernel backends that backend selection can return.

    Member values are fixed integers. ``NONE`` signals that no compatible
    backend was found for the current platform.
    """

    NONE = 0

    # FlashInfer-based backends (SM100 / SM90 device capabilities)
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4

    # Marlin kernels
    MARLIN = 5

    # triton_kernels-based backend
    TRITON = 6


82
83
84
85
86
87
88
89
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    """Pick an MXFP4 MoE backend that is known to work together with LoRA.

    Not all MXFP4 backends support LoRA, so FlashInfer backends are never
    considered here; on CUDA the choice is only between Triton
    (``triton_kernels``) and Marlin.

    Returns:
        ``Mxfp4Backend.NONE`` on non-CUDA platforms. On CUDA,
        ``Mxfp4Backend.TRITON`` when triton_kernels are usable and
        ``VLLM_MXFP4_USE_MARLIN`` is not set, otherwise ``Mxfp4Backend.MARLIN``.
    """
    if not current_platform.is_cuda():
        return Mxfp4Backend.NONE

    # triton_kernels need torch >= 2.8 and are only confirmed on SM90 and SM100.
    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
    # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
    can_use_triton = (
        has_triton_kernels()
        and is_torch_equal_or_newer("2.8.0")
        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
    )

    if not envs.VLLM_MXFP4_USE_MARLIN and can_use_triton:
        logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
        return Mxfp4Backend.TRITON

    # Marlin is used whenever it is explicitly requested or Triton is unusable.
    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
    return Mxfp4Backend.MARLIN


def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    """Select the MXFP4 MoE backend for the current platform.

    On CUDA:
      * SM90 uses FlashInfer MXFP4/BF16 only when FlashInfer is installed and
        ``VLLM_USE_FLASHINFER_MOE_MXFP4_BF16`` is set.
      * SM100 always uses a FlashInfer backend when FlashInfer is installed:
        MXFP8 CUTLASS or MXFP8 TRT-LLM when opted into via env vars,
        MXFP4/BF16 otherwise.
      * Everything else falls back to Marlin or Triton.
    XPU always uses Marlin; ROCm uses Triton when ``triton_kernels`` exists.

    Args:
        with_lora_support: If True, restrict the choice to LoRA-capable
            backends by delegating to ``get_mxfp4_backend_with_lora``.

    Returns:
        The selected backend, or ``Mxfp4Backend.NONE`` if none applies.
    """
    if with_lora_support:
        return get_mxfp4_backend_with_lora()

    if current_platform.is_cuda():
        is_sm90 = current_platform.is_device_capability(90)
        is_sm100 = current_platform.is_device_capability(100)
        # Only probe for FlashInfer on architectures that have FlashInfer
        # MXFP4 kernels (same short-circuiting as checking capability first).
        flashinfer_available = (is_sm90 or is_sm100) and has_flashinfer()

        if (
            is_sm90
            and flashinfer_available
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        ):
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16

        if is_sm100 and flashinfer_available:
            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS:
                logger.info_once(
                    "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100"
                )
                return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8:
                # Previously the only FlashInfer choice that was not logged.
                logger.info_once(
                    "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100"
                )
                return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy."
            )
            return Mxfp4Backend.SM100_FI_MXFP4_BF16

        if (is_sm90 or is_sm100) and not flashinfer_available:
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results."
            )

        # No FlashInfer backend was selected (FlashInfer missing, unsupported
        # architecture, or SM90 without the opt-in env var): use Marlin/Triton.
        triton_kernels_supported = (
            has_triton_kernels()
            and is_torch_equal_or_newer("2.8.0")
            # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
            # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
            # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
            and (9, 0) <= current_platform.get_device_capability() < (11, 0)
        )
        if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
            logger.info_once("Using Marlin backend")
            return Mxfp4Backend.MARLIN
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    if current_platform.is_xpu():
        logger.info_once("Using ipex marlin backend on XPU")
        return Mxfp4Backend.MARLIN

    if current_platform.is_rocm() and has_triton_kernels():
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    return Mxfp4Backend.NONE


class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 checkpoints.

    Only fused-MoE layers receive an MXFP4 quant method. Linear layers always
    use ``UnquantizedLinearMethod``, and attention layers get no quant method.
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Prefixes of layers to leave unquantized; only consulted for
        # LinearBase layers in get_quant_method.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # The supplied config is not inspected; a default instance is built.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum device capability accepted by this quantization method.
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        # Only bfloat16 activations are supported.
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # No extra config files are looked up for this method.
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``.

        Args:
            layer: Module being constructed.
            prefix: Name prefix of ``layer``, used for skip matching and to
                choose the Marlin input dtype.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers, an MXFP4 MoE method
            for fused-MoE layers, and ``None`` for everything else.
        """
        if isinstance(layer, LinearBase):
            explicitly_skipped = bool(self.ignored_layers) and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if not explicitly_skipped:
                # TODO: Add support for MXFP4 Linear Method.
                # MXFP4 LinearMethod is available in AMD-Quark, refer to that
                # implementation if you are interested in enabling MXFP4 here.
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            # Linear layers run unquantized whether or not they were skipped.
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            if current_platform.is_xpu():
                return IpexMxfp4MoEMethod(layer.moe_config)
            moe_method = Mxfp4MoEMethod(layer.moe_config)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method

        if isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )

        return None


class Mxfp4MoEMethod(FusedMoEMethodBase):
    def __init__(self, moe: FusedMoEConfig):
        """Select the MXFP4 MoE backend for this layer and set up state.

        Args:
            moe: Fused-MoE configuration of the layer. ``moe.is_lora_enabled``
                restricts backend selection to LoRA-capable backends.

        Raises:
            AssertionError: if ``get_mxfp4_backend`` returns
                ``Mxfp4Backend.NONE`` (no compatible backend found).
        """
        super().__init__(moe)
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)

        # May be overwritten by Mxfp4Config.get_quant_method; passed to
        # prepare_moe_fp4_layer_for_marlin when the Marlin backend is used.
        self.marlin_input_dtype = None
        self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # The message pieces are implicitly concatenated, so each one must end
        # with a space (previously rendered "...foundno compatible..." and
        # "...Triton).Please...").
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
            "Please check your environment and try again."
        )
        # Permutation-index cache keyed by torch.Size; handed to FlashInfer's
        # get_w2_permute_indices_with_cache during weight post-processing.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

256
257
258
259
260
261
262
263
264
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate zero-initialised MXFP4 expert weights, scales and biases.

        The intermediate and hidden sizes are first padded according to the
        selected backend; the padded values are stored on ``self`` as
        ``intermediate_size`` and ``hidden_size``. Weights and scales are
        ``torch.uint8`` tensors. The reduction dimension of each weight is
        halved (two 4-bit values per byte, cf. ``float4_e2m1fn_x2`` below),
        and each scale covers a block of ``mxfp4_block`` (32) elements.
        Biases are bfloat16.
        """
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        # FIXME (zyongye): ship after torch and safetensors support mxfp4
        # is_torch_mxfp4_available = (
        #     hasattr(torch, "float4_e2m1fn_x2") and
        #     hasattr(torch, "float8_e8m0fnu"))
        # if is_torch_mxfp4_available:
        #     weight_dtype = torch.float4_e2m1fn_x2
        #     scale_dtype = torch.float8_e8m0fnu

        mxfp4_block = 32

        backend = self.mxfp4_backend
        padded_inter = intermediate_size_per_partition

        if backend == Mxfp4Backend.MARLIN:
            # The moe marlin kernel requires n % 256 == 0 and k % 128 == 0
            # for each linear:
            #   gate_up_proj: n = 2 * padded intermediate, k = hidden_size
            #   down_proj:    n = hidden_size,             k = padded intermediate
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(
                hidden_size, 128 if current_platform.is_xpu() else 256
            )
            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = padded_inter
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # Pad the intermediate size to a multiple of 2 * mxfp4_block so it
            # can hold non-uniformly sharded tensors and swizzling; the extra
            # padding is for performance.
            padded_inter = round_up(intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 128)
        elif current_platform.is_rocm():
            pad_align = get_padding_alignment()
            padded_inter = round_up(intermediate_size_per_partition, pad_align)
            hidden_size = round_up(hidden_size, pad_align)
        else:
            padded_inter = round_up(intermediate_size_per_partition, 64)

        self.intermediate_size = padded_inter
        self.hidden_size = hidden_size

        def _register_zeros(
            name: str, shape: tuple[int, ...], dtype: torch.dtype
        ) -> None:
            # Allocate a frozen zero tensor, attach it to the layer and tag it
            # with the caller-provided weight attributes.
            param = torch.nn.Parameter(
                torch.zeros(*shape, dtype=dtype), requires_grad=False
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel)
        _register_zeros(
            "w13_weight",
            (num_experts, 2 * padded_inter, hidden_size // 2),
            weight_dtype,
        )
        _register_zeros(
            "w13_weight_scale",
            (num_experts, 2 * padded_inter, hidden_size // mxfp4_block),
            scale_dtype,
        )
        _register_zeros(
            "w13_bias",
            (num_experts, 2 * padded_inter),
            torch.bfloat16,
        )

        # down_proj (row parallel)
        _register_zeros(
            "w2_weight",
            (num_experts, hidden_size, padded_inter // 2),
            weight_dtype,
        )
        _register_zeros(
            "w2_weight_scale",
            (num_experts, hidden_size, padded_inter // mxfp4_block),
            scale_dtype,
        )
        _register_zeros(
            "w2_bias",
            (num_experts, hidden_size),
            torch.bfloat16,
        )

    def process_weights_after_loading(self, layer):
408
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
409
            prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
410
411
412
413
414
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
415
            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
416
417
418
419
420
421
422
423
424
425
426
427
428

            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
429
430
            sf_block_size = 32  # mxfp4 block size

431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
465
466
467
468
469
470
471
472

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

co63oc's avatar
co63oc committed
473
            # Swap w1 and w3 as the definition of
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
506
                # w13 weight shuffling
507
                permute_indices = get_w2_permute_indices_with_cache(
508
509
510
511
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
512
513
514
515
516
                gemm1_weights_mxfp4_shuffled.append(
                    w13_weight[i]
                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                    .contiguous()
                )
517
                # w13 scale shuffling
518
                permute_sf_indices = get_w2_permute_indices_with_cache(
519
520
521
522
523
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
524
                gemm1_scales_mxfp4_shuffled.append(
525
526
527
528
529
530
531
532
                    nvfp4_block_scale_interleave(
                        w13_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w13_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
533
                # w13 bias shuffling
534
                permute_bias_indices = get_w2_permute_indices_with_cache(
535
536
537
538
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
539
540
541
542
543
544
                gemm1_bias_shuffled.append(
                    w13_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                    .contiguous()
                )
545
                # w2 weight shuffling
546
                permute_indices = get_w2_permute_indices_with_cache(
547
548
549
550
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
551
552
553
554
555
                gemm2_weights_mxfp4_shuffled.append(
                    w2_weight[i]
                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                    .contiguous()
                )
556
                # w2 scale shuffling
557
                permute_sf_indices = get_w2_permute_indices_with_cache(
558
559
560
561
562
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
563
                gemm2_scales_mxfp4_shuffled.append(
564
565
566
567
568
569
570
571
                    nvfp4_block_scale_interleave(
                        w2_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w2_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
572
                # w2 bias shuffling
573
                permute_indices = get_w2_permute_indices_with_cache(
574
575
576
577
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
578
579
580
581
582
583
                gemm2_bias_shuffled.append(
                    w2_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                    .contiguous()
                )
584
585

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
586
587
588
589
590
591
592
593
594
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
595
596

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
597
598
599
600
601
602
603
604
605
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
606
607

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
608
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
609
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
610
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
611
612
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
                requires_grad=False,
            )
            layer.w2_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
635
636
637
638

            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
698
699
                    w13_scale_swapped.view(torch.uint8)
                ).reshape(orig_shape)
700
701
702
703

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
704
705
706
707
708
709
710
711
712
713
714
                    w2_s.view(torch.uint8)
                ).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
                layer.w13_weight_scale = Parameter(
                    w13_scale_interleaved, requires_grad=False
                )
                layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
                layer.w2_weight_scale = Parameter(
                    w2_scale_interleaved, requires_grad=False
                )
715
716
717
718
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
719
720
721
                    w_interleaved = w.reshape(
                        w_shape[0], w_shape[1], (w_shape[2] // 4), 4
                    )
722
723
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
724
725
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                    )
726
727
                    return w_interleaved

728
729
                w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
730
731
732

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
733
734
735
736
737
738
739
740
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)

                layer.w13_weight = torch.nn.Parameter(
                    torch.cat([w3_w, w1_w], dim=1), requires_grad=False
                )
                layer.w13_bias = torch.nn.Parameter(
                    w13_bias_swapped, requires_grad=False
                )
741
                layer.w13_weight_scale = torch.nn.Parameter(
742
743
                    w31_scales_interleaved, requires_grad=False
                )
744
                layer.w2_weight_scale = torch.nn.Parameter(
745
746
                    w2_scales_interleaved, requires_grad=False
                )
747
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
748
749
750
751
752
753
754
755
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)

756
757
758
759
760
            # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
            # (stored in self.fused_experts) to determine if the MoE has a
            # batched activation format. As self.fused_experts is not
            # initialized at this point, we resort to checking the MoE config
            # directly.
761
            is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
762
            if is_batched_moe:
763
764
765
766
767
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8

            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
768
769
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
770
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
771
772
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )
773
774

            self.w13_precision_config = PrecisionConfig(
775
776
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
777
            self.w2_precision_config = PrecisionConfig(
778
779
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )
780

781
782
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
783
784
785
786
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
787
788
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
789

790
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quantization config for the selected backend.

        Args:
            layer: The MoE layer whose bias and weight-scale parameters are
                referenced by the returned config.

        Returns:
            The quantization config matching ``self.mxfp4_backend``:
            W4A16 for Marlin / Triton / SM100 FlashInfer BF16, MXFP4+MXFP8
            for the SM100 FlashInfer MXFP8 backends, and the generic OCP MX
            "mxfp4" config for any remaining backend.
        """
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
            # The Triton kernels take the swizzled scales wrapped in the
            # PrecisionConfig objects built in process_weights_after_loading,
            # not the raw layer scale tensors.
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=self.w13_precision_config,
                w2_scale=self.w2_precision_config,
            )
        elif self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        ):
            # These backends quantize activations to MXFP8 in apply().
            return mxfp4_mxfp8_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )
        elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )
        else:
            # Remaining backends (including SM90_FI_MXFP4_BF16).
            return ocp_mx_moe_quant_config(
                quant_dtype="mxfp4",
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the modular-kernel experts implementation for this backend.

        Args:
            prepare_finalize: The prepare/finalize stage; its
                ``activation_format`` selects batched vs. standard experts.
            layer: The MoE layer; its ``gemm1_*`` attributes are forwarded to
                the TRT-LLM experts implementation.

        Returns:
            The experts implementation matching ``self.mxfp4_backend``.

        Raises:
            NotImplementedError: If the backend has no experts implementation
                for the requested activation format.
        """
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # Batched activation format: only Marlin provides an
            # implementation here.
            if self.mxfp4_backend == Mxfp4Backend.MARLIN:
                max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
                assert max_num_tokens_per_rank is not None
                assert self.moe_quant_config is not None
                return BatchedMarlinExperts(
                    max_num_tokens=max_num_tokens_per_rank,
                    num_dispatchers=prepare_finalize.num_dispatchers(),
                    quant_config=self.moe_quant_config,
                )
            raise NotImplementedError(
                f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                "EP batched experts format"
            )

        assert self.moe_quant_config is not None
        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            # B200 code-path
            kwargs = {
                "gemm1_alpha": layer.gemm1_alpha,
                "gemm1_beta": layer.gemm1_beta,
                "gemm1_clamp_limit": layer.gemm1_clamp_limit,
                # TODO(bnell): part of quant_config
                "max_capture_size": self.max_capture_size,
            }
            return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
        elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
            return MarlinExperts(self.moe_quant_config)
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
            # The unfused Triton variant is selected when LoRA is enabled.
            if self.moe.is_lora_enabled:
                return UnfusedOAITritonExperts(self.moe_quant_config)
            return OAITritonExperts(self.moe_quant_config)
        else:
            raise NotImplementedError(
                f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
            )

    @property
    def allow_inplace(self) -> bool:
        """Return ``True`` unconditionally.

        Presumably this lets the fused-MoE layer run this method in place;
        confirm against the base-class contract.
        """
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MXFP4 MoE forward pass with the selected backend.

        Dispatch on ``self.mxfp4_backend``:

        * ``MARLIN``: route via ``layer.select_experts`` and call
          ``fused_marlin_moe``.
        * ``SM100_FI_MXFP4_MXFP8_TRTLLM`` / ``SM100_FI_MXFP4_BF16``: FlashInfer
          ``trtllm_fp4_block_scale_moe``, which is given the raw
          ``router_logits`` and does routing itself.
        * ``SM100_FI_MXFP4_MXFP8_CUTLASS`` / ``SM90_FI_MXFP4_BF16``: route via
          ``layer.select_experts``, then FlashInfer CUTLASS fused MoE.
        * ``TRITON``: ``triton_kernel_moe_forward``.

        Args:
            layer: The MoE layer holding the post-processed weights, scales,
                biases and the ``gemm1_*`` SwiGLU parameters.
            x: Input hidden states.
            router_logits: Router logits for ``x``.
            top_k: Number of experts selected per token.
            renormalize: Whether to renormalize the top-k routing weights.
            routed_scaling_factor: Accepted for interface compatibility; not
                read by this method.
            enable_eplb: Must be ``False``; EPLB is not supported.
            The remaining routing arguments are validated by
            ``_can_support_mxfp4`` for every non-Marlin backend.

        Returns:
            The MoE output for ``x``.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
            AssertionError: If the routing configuration is unsupported, or
                ``x`` is not bfloat16 on a BF16-activation backend.
            ValueError: If ``self.mxfp4_backend`` is not recognized.
        """
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            topk_weights, topk_ids, _ = layer.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )

            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_bias,
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=None,
                global_scale2=None,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                activation=activation,
                expert_map=expert_map,
                input_dtype=self.marlin_input_dtype,
            )

        assert _can_support_mxfp4(
            use_grouped_topk,
            topk_group,
            num_expert_group,
            expert_map,
            custom_routing_function,
            e_score_correction_bias,
            apply_router_weight_on_input,
            scoring_func,
            activation,
            expert_load_view,
            logical_to_physical_map,
            logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer import trtllm_fp4_block_scale_moe

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                # BF16 activations are passed to the kernel unquantized.
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                global_num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                self.num_experts,  # local num experts
                None,  # presumably routed_scaling_factor - confirm vs flashinfer
                None,  # presumably tile_tokens_dim - confirm vs flashinfer
                1 if renormalize else 0,  # routing_method_type, renormalize
                True,  # do finalize
                tune_max_num_tokens=max(self.max_capture_size, 1),
            )[0]
            return trtllm_gen_output
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

            topk_weights, topk_ids, _ = layer.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )

            # Backend-specific preparation
            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, True, 32)

                # All-ones per-expert input scales; presumably the real
                # scaling is carried by the block scales - TODO confirm.
                fake_input_scale = torch.ones(self.num_experts, device=x.device)
                quant_scales = [
                    layer.w13_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                    layer.w2_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                ]

                fi_input = x_quant
                extra_kwargs = dict(
                    use_mxfp8_act_scaling=True,
                    input_sf=x_scale,
                    fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
                    fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
                )
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
                assert x.dtype == torch.bfloat16

                quant_scales = [
                    layer.w13_weight_scale,
                    layer.w2_weight_scale,
                ]

                fi_input = x
                extra_kwargs = dict(
                    use_w4_group_scaling=True,
                    fc1_expert_weights=layer.w13_weight,
                    fc2_expert_weights=layer.w2_weight,
                )

            # The kernel writes into the preallocated ``output`` buffer, which
            # is what this branch returns.
            output = torch.empty_like(x, dtype=torch.bfloat16)
            flashinfer_cutlass_fused_moe(
                input=fi_input,
                token_selected_experts=topk_ids.to(torch.int).contiguous(),
                token_final_scales=topk_weights,
                output_dtype=torch.bfloat16,
                output=output,
                quant_scales=quant_scales,
                fc1_expert_biases=layer.w13_bias,
                fc2_expert_biases=layer.w2_bias,
                swiglu_alpha=layer.gemm1_alpha,
                swiglu_beta=layer.gemm1_beta,
                swiglu_limit=layer.gemm1_clamp_limit,
                tp_size=self.moe.tp_size,
                tp_rank=self.moe.tp_rank,
                ep_size=self.moe.ep_size,
                ep_rank=self.moe.ep_rank,
                tune_max_num_tokens=max(self.max_capture_size, 1),
                **extra_kwargs,
            )

            return output
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")


class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 MoE method backed by Intel Extension for PyTorch (IPEX).

    Reuses ``Mxfp4MoEMethod.create_weights`` for parameter creation, then
    hands the loaded weights to ``ipex.llm.modules.GatedMLPMOE``, which is
    invoked from ``apply`` to perform routing and the expert computation.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        # Read in process_weights_after_loading for the EP expert offset.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the MXFP4 parameters and remember the unpadded hidden size.

        ``apply`` pads activations up to a multiple of 128 and slices the
        output back to ``self.original_hidden_size``.
        """
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reinterpret the packed weights as int32 and build the IPEX module.

        Stores the resulting ``GatedMLPMOE`` instance on
        ``layer.ipex_fusion`` for use by ``apply``.
        """
        import intel_extension_for_pytorch as ipex

        # Reinterpret the weight storage as int32 (a dtype view, no copy);
        # presumably the layout GatedMLPMOE expects with is_mxfp4=True.
        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
        # Id of this rank's first expert under expert parallelism.
        ep_rank_start = self.moe_config.ep_rank * self.moe_config.num_local_experts
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            is_mxfp4=True,
            experts_start_id=ep_rank_start,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MoE forward pass through ``layer.ipex_fusion``.

        Only ``use_grouped_topk``, ``top_k``, ``router_logits``,
        ``renormalize``, ``topk_group`` and ``num_expert_group`` are forwarded
        to IPEX; the remaining routing/EPLB arguments are accepted for
        interface compatibility and not used here.

        Raises:
            AssertionError: If ``activation`` is not ``"swigluoai"``.
        """
        assert activation == "swigluoai", (
            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
        )
        # Zero-pad the hidden dim up to a multiple of 128 (presumably required
        # by the IPEX kernel), then strip the padding from the result.
        hidden_size_pad = round_up(self.original_hidden_size, 128)
        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
        hidden_states = layer.ipex_fusion(
            x_pad,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            activation="swiglu_oai",
        )
        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
        return hidden_states