mxfp4.py 45.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable
4
from enum import Enum
5
from typing import Optional
6
7
8
9
10

import torch
from torch.nn.parameter import Parameter

from vllm import envs
11
from vllm.config import get_current_vllm_config
12
from vllm.logger import init_logger
13
14
15
16
17
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
18
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
19
from vllm.model_executor.layers.fused_moe.config import (
20
    FusedMoEQuantConfig,
21
    mxfp4_mxfp8_moe_quant_config,
22
    mxfp4_w4a16_moe_quant_config,
23
    ocp_mx_moe_quant_config,
24
)
25
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
26
    BatchedMarlinExperts,
27
28
29
    MarlinExperts,
    fused_marlin_moe,
)
30
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
31
32
    OAITritonExperts,
)
33
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
34
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
35
36
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
37
38
39
    QuantizationConfig,
    QuantizeMethodBase,
)
40
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
41
42
    prepare_moe_fp4_layer_for_marlin,
)
43
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
44
45
    _can_support_mxfp4,
    _swizzle_mxfp4,
46
    get_padding_alignment,
47
48
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
49
from vllm.model_executor.utils import set_weight_attrs
50
from vllm.platforms import current_platform
51
from vllm.scalar_type import scalar_types
52
from vllm.utils.flashinfer import has_flashinfer
53
from vllm.utils.import_utils import has_triton_kernels
Cyrus Leung's avatar
Cyrus Leung committed
54
from vllm.utils.math_utils import round_up
55
from vllm.utils.torch_utils import is_torch_equal_or_newer
56

57
58
59
# Module-level logger for this file's backend-selection and layer-setup messages.
logger = init_logger(__name__)


60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Backends that can execute MXFP4 mixture-of-experts layers.
class Mxfp4Backend(Enum):
    """Kernel backend chosen for MXFP4 MoE layers.

    ``NONE`` means no compatible backend was found. The explicit integer
    values are part of the enum's identity and must stay stable.
    """

    # Sentinel: nothing usable on this platform.
    NONE = 0

    # --- FlashInfer kernels (SM100 = Blackwell, SM90 = Hopper) ---
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1  # MXFP4 weights, MXFP8 activations, TRT-LLM gen
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2  # MXFP4 weights, MXFP8 activations, CUTLASS
    SM100_FI_MXFP4_BF16 = 3  # MXFP4 weights, BF16 activations
    SM90_FI_MXFP4_BF16 = 4  # MXFP4 weights, BF16 activations on Hopper

    # --- Marlin kernels ---
    MARLIN = 5

    # --- Triton kernels (triton_kernels package) ---
    TRITON = 6


77
78
79
80
81
82
83
84
85
86
87
88
89
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    """Choose an MXFP4 MoE backend that is known to work with LoRA.

    Not every MXFP4 backend supports LoRA, so the choice is limited to the
    ones verified to do so. At present that is Marlin, and only on CUDA.

    Returns:
        ``Mxfp4Backend.MARLIN`` on CUDA, ``Mxfp4Backend.NONE`` elsewhere.
    """
    if current_platform.is_cuda():
        logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
        return Mxfp4Backend.MARLIN
    # No LoRA-capable MXFP4 backend is available off CUDA.
    return Mxfp4Backend.NONE


def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    """Select the MXFP4 MoE backend for the current platform.

    On CUDA, FlashInfer backends are considered first (SM90 / SM100 only);
    otherwise the choice falls back to Marlin or Triton. XPU uses Marlin, and
    ROCm uses Triton when ``triton_kernels`` is installed.

    Args:
        with_lora_support: Restrict the choice to LoRA-capable backends
            (delegates to ``get_mxfp4_backend_with_lora``).

    Returns:
        The selected backend, or ``Mxfp4Backend.NONE`` when no compatible
        backend exists on this platform.
    """
    if with_lora_support:
        return get_mxfp4_backend_with_lora()

    if current_platform.is_cuda():
        # Pure capability queries, hoisted so the branch conditions read
        # clearly. has_flashinfer() stays lazy (short-circuited) as before.
        is_sm90 = current_platform.is_device_capability(90)
        is_sm100 = current_platform.is_device_capability(100)

        # FlashInfer backends, in priority order. SM90 BF16 and both SM100
        # MXFP8 variants are opt-in via env flags; SM100 BF16 is the default
        # whenever FlashInfer is installed on SM100.
        if is_sm90 and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16
        elif (
            is_sm100
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
        ):
            logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
        elif is_sm100 and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8:
            # Log the selection like every other selected backend does.
            logger.info_once("Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
        elif is_sm100 and has_flashinfer():
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy."
            )
            return Mxfp4Backend.SM100_FI_MXFP4_BF16
        elif (is_sm100 or is_sm90) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results."
            )

        # No FlashInfer backend was selected. FlashInfer may be missing, not
        # enabled via env flags, or unsupported on this capability. Fall back
        # to Marlin or Triton.
        capability = current_platform.get_device_capability()
        triton_kernels_supported = (
            has_triton_kernels()
            and is_torch_equal_or_newer("2.8.0")
            # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
            # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
            # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
            # get_device_capability() may return None; never compare None to
            # a tuple.
            and capability is not None
            and (9, 0) <= capability < (11, 0)
        )
        if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
            logger.info_once("Using Marlin backend")
            return Mxfp4Backend.MARLIN
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    elif current_platform.is_xpu():
        logger.info_once("Using ipex marlin backend on XPU")
        return Mxfp4Backend.MARLIN

    elif current_platform.is_rocm() and has_triton_kernels():
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    return Mxfp4Backend.NONE
157
158
159


class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 checkpoints.

    Only ``FusedMoE`` layers receive an MXFP4 method. Linear layers fall back
    to ``UnquantizedLinearMethod``, and attention layers get no quantize
    method (``None``).
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Layer prefixes to keep unquantized; consulted for linear layers in
        # get_quant_method via is_layer_skipped.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # NOTE(review): ``config`` is not read here, so ``ignored_layers`` is
        # never populated from the checkpoint config — confirm this is intended.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # No extra quantization config file is required.
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method to use for ``layer``.

        Args:
            layer: Module being constructed.
            prefix: Module name prefix, matched against ``ignored_layers``.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers, an MXFP4 MoE method
            for ``FusedMoE`` layers (the IPEX variant on XPU), and ``None``
            for anything else, including attention layers.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            # TODO: Add support for MXFP4 Linear Method.
            # MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
            # if you are interested in enabling MXFP4 here.
            logger.debug_once(
                "MXFP4 linear layer is not implemented - falling back to "
                "UnquantizedLinearMethod.",
                scope="local",
            )
            return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            if current_platform.is_xpu():
                return IpexMxfp4MoEMethod(layer.moe_config)
            else:
                return Mxfp4MoEMethod(layer.moe_config)
        elif isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )
        return None


class Mxfp4MoEMethod(FusedMoEMethodBase):
    def __init__(self, moe: FusedMoEConfig):
        """Pick the MXFP4 backend for this MoE layer.

        Args:
            moe: MoE layer configuration; ``moe.is_lora_enabled`` restricts
                backend selection to LoRA-capable backends.

        Raises:
            AssertionError: If no compatible MXFP4 MoE backend is available.
        """
        super().__init__(moe)
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
        self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # Each implicitly concatenated piece ends with a space so the message
        # reads "... found no compatible ... (FlashInfer/Marlin/Triton). Please".
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
            "Please check your environment and try again."
        )
        # Permutation indices for the FlashInfer weight shuffling, memoized by
        # tensor shape so they are computed once and reused across experts.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the MXFP4 MoE parameters on ``layer``.

        Hidden and intermediate sizes are first padded to the alignment the
        selected backend requires. Then, in order, ``w13_weight``,
        ``w13_weight_scale``, ``w13_bias``, ``w2_weight``, ``w2_weight_scale``
        and ``w2_bias`` are registered: zero-filled, ``requires_grad=False``,
        and tagged with ``extra_weight_attrs``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts (leading dim of every parameter).
            hidden_size: Unpadded hidden size.
            intermediate_size_per_partition: Unpadded per-partition
                intermediate size.
            params_dtype: Recorded on ``layer`` for the Marlin backend only.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``.

        Side effects:
            Sets ``self.num_experts``, ``self.intermediate_size`` and
            ``self.hidden_size`` (the padded sizes used for allocation).
        """
        self.num_experts = num_experts
        # Weights and scales are stored as raw bytes. The weight's last dim is
        # halved because two FP4 values are packed into each uint8.
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        # FIXME (zyongye): ship after torch and safetensors support mxfp4
        # is_torch_mxfp4_available = (
        #     hasattr(torch, "float4_e2m1fn_x2") and
        #     hasattr(torch, "float8_e8m0fnu"))
        # if is_torch_mxfp4_available:
        #     weight_dtype = torch.float4_e2m1fn_x2
        #     scale_dtype = torch.float8_e8m0fnu

        # One scale per mxfp4_block elements of each row's last dimension.
        mxfp4_block = 32

        intermediate_size_per_partition_after_pad = intermediate_size_per_partition
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            # The moe marlin kernel requires that for each linear
            # n % 256 == 0 and k % 128 == 0.
            # In gate_up_proj:
            #    n = 2 * intermediate_size_per_partition_after_pad
            #    k = hidden_size
            # In down_proj
            #    n = hidden_size
            #    k = intermediate_size_per_partition_after_pad
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 128
            )
            if current_platform.is_xpu():
                hidden_size = round_up(hidden_size, 128)
            else:
                hidden_size = round_up(hidden_size, 256)

            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = (
                intermediate_size_per_partition_after_pad
            )
        elif self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # Pad both sizes to multiples of 256. This covers the
            # 2 * mxfp4_block alignment needed to hold non-uniformly sharded
            # tensors and for swizzling; the extra padding is for performance.
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 256
            )
            hidden_size = round_up(hidden_size, 256)
        elif self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 128
            )
            hidden_size = round_up(hidden_size, 128)
        elif current_platform.is_rocm():
            pad_align = get_padding_alignment()
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, pad_align
            )
            hidden_size = round_up(hidden_size, pad_align)
        else:
            # Remaining case (e.g. the Triton backend on CUDA): only the
            # intermediate size is padded; hidden_size is used as-is.
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 64
            )

        self.intermediate_size = intermediate_size_per_partition_after_pad
        self.hidden_size = hidden_size
        padded_inter = intermediate_size_per_partition_after_pad

        def _register(
            name: str, shape: tuple[int, ...], dtype: torch.dtype
        ) -> None:
            """Register a zero-filled frozen parameter and tag its attributes."""
            param = torch.nn.Parameter(
                torch.zeros(shape, dtype=dtype), requires_grad=False
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel)
        _register(
            "w13_weight",
            (num_experts, 2 * padded_inter, hidden_size // 2),
            weight_dtype,
        )
        _register(
            "w13_weight_scale",
            (num_experts, 2 * padded_inter, hidden_size // mxfp4_block),
            scale_dtype,
        )
        _register("w13_bias", (num_experts, 2 * padded_inter), torch.bfloat16)

        # down_proj (row parallel)
        _register(
            "w2_weight",
            (num_experts, hidden_size, padded_inter // 2),
            weight_dtype,
        )
        _register(
            "w2_weight_scale",
            (num_experts, hidden_size, padded_inter // mxfp4_block),
            scale_dtype,
        )
        _register("w2_bias", (num_experts, hidden_size), torch.bfloat16)

    def process_weights_after_loading(self, layer):
388
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
389
            prepare_moe_fp4_layer_for_marlin(layer)
390
391
392
393
394
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
395
            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
396
397
398
399
400
401
402
403
404
405
406
407
408

            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
409
410
            sf_block_size = 32  # mxfp4 block size

411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
445
446
447
448
449
450
451
452

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

co63oc's avatar
co63oc committed
453
            # Swap w1 and w3 as the definition of
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
486
                # w13 weight shuffling
487
                permute_indices = get_w2_permute_indices_with_cache(
488
489
490
491
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
492
493
494
495
496
                gemm1_weights_mxfp4_shuffled.append(
                    w13_weight[i]
                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                    .contiguous()
                )
497
                # w13 scale shuffling
498
                permute_sf_indices = get_w2_permute_indices_with_cache(
499
500
501
502
503
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
504
                gemm1_scales_mxfp4_shuffled.append(
505
506
507
508
509
510
511
512
                    nvfp4_block_scale_interleave(
                        w13_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w13_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
513
                # w13 bias shuffling
514
                permute_bias_indices = get_w2_permute_indices_with_cache(
515
516
517
518
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
519
520
521
522
523
524
                gemm1_bias_shuffled.append(
                    w13_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                    .contiguous()
                )
525
                # w2 weight shuffling
526
                permute_indices = get_w2_permute_indices_with_cache(
527
528
529
530
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
531
532
533
534
535
                gemm2_weights_mxfp4_shuffled.append(
                    w2_weight[i]
                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                    .contiguous()
                )
536
                # w2 scale shuffling
537
                permute_sf_indices = get_w2_permute_indices_with_cache(
538
539
540
541
542
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
543
                gemm2_scales_mxfp4_shuffled.append(
544
545
546
547
548
549
550
551
                    nvfp4_block_scale_interleave(
                        w2_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w2_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
552
                # w2 bias shuffling
553
                permute_indices = get_w2_permute_indices_with_cache(
554
555
556
557
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
558
559
560
561
562
563
                gemm2_bias_shuffled.append(
                    w2_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                    .contiguous()
                )
564
565

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
566
567
568
569
570
571
572
573
574
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
575
576

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
577
578
579
580
581
582
583
584
585
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
586
587

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
588
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
589
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
590
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
591
592
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
                requires_grad=False,
            )
            layer.w2_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
615
616
617
618

            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
678
679
                    w13_scale_swapped.view(torch.uint8)
                ).reshape(orig_shape)
680
681
682
683

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
684
685
686
687
688
689
690
691
692
693
694
                    w2_s.view(torch.uint8)
                ).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
                layer.w13_weight_scale = Parameter(
                    w13_scale_interleaved, requires_grad=False
                )
                layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
                layer.w2_weight_scale = Parameter(
                    w2_scale_interleaved, requires_grad=False
                )
695
696
697
698
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
699
700
701
                    w_interleaved = w.reshape(
                        w_shape[0], w_shape[1], (w_shape[2] // 4), 4
                    )
702
703
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
704
705
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                    )
706
707
                    return w_interleaved

708
709
                w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
710
711
712

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
713
714
715
716
717
718
719
720
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)

                layer.w13_weight = torch.nn.Parameter(
                    torch.cat([w3_w, w1_w], dim=1), requires_grad=False
                )
                layer.w13_bias = torch.nn.Parameter(
                    w13_bias_swapped, requires_grad=False
                )
721
                layer.w13_weight_scale = torch.nn.Parameter(
722
723
                    w31_scales_interleaved, requires_grad=False
                )
724
                layer.w2_weight_scale = torch.nn.Parameter(
725
726
                    w2_scales_interleaved, requires_grad=False
                )
727
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
728
729
730
731
732
733
734
735
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)

736
737
738
739
740
            # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
            # (stored in self.fused_experts) to determine if the MoE has a
            # batched activation format. As self.fused_experts is not
            # initialized at this point, we resort to checking the MoE config
            # directly.
741
            is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
742
            if is_batched_moe:
743
744
745
746
747
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8

            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
748
749
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
750
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
751
752
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )
753
754

            self.w13_precision_config = PrecisionConfig(
755
756
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
757
            self.w2_precision_config = PrecisionConfig(
758
759
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )
760

761
762
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
763
764
765
766
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
767
768
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
769

770
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quantization config for the selected backend.

        Args:
            layer: The MoE layer whose post-processed biases and weight scales
                are referenced by the returned config.

        Returns:
            The ``FusedMoEQuantConfig`` matching ``self.mxfp4_backend``:
            ``mxfp4_w4a16`` for MARLIN / SM100_FI_MXFP4_BF16 / TRITON,
            ``mxfp4_mxfp8`` for the SM100 FlashInfer MXFP8 backends, and an
            OCP-MX ``"mxfp4"`` config for every remaining backend.
        """
        # MXFP4 weights with un-quantized activations; scales are read
        # directly from the layer parameters.
        if self.mxfp4_backend in (
            Mxfp4Backend.MARLIN,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )

        if self.mxfp4_backend == Mxfp4Backend.TRITON:
            # The Triton path carries its swizzled scales inside the
            # PrecisionConfig objects set up during weight post-processing,
            # so those objects stand in for the raw scale tensors.
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=self.w13_precision_config,
                w2_scale=self.w2_precision_config,
            )

        if self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        ):
            # These backends quantize activations to MXFP8 at apply time.
            return mxfp4_mxfp8_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )

        # All remaining backends (SM90_FI_MXFP4_BF16 among them).
        return ocp_mx_moe_quant_config(
            quant_dtype="mxfp4",
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation that pairs with ``prepare_finalize``.

        Args:
            prepare_finalize: The prepare/finalize object chosen for this
                layer. Its ``activation_format`` decides between batched and
                standard experts kernels.
            layer: The MoE layer. The TRT-LLM path reads its ``gemm1_alpha``,
                ``gemm1_beta`` and ``gemm1_clamp_limit`` attributes.

        Returns:
            The ``FusedMoEPermuteExpertsUnpermute`` for ``self.mxfp4_backend``.

        Raises:
            NotImplementedError: If the backend has no experts implementation
                for the requested activation format.
        """
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # Only Marlin offers a batched-format experts kernel here.
            if self.mxfp4_backend == Mxfp4Backend.MARLIN:
                max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
                assert max_num_tokens_per_rank is not None
                assert self.moe_quant_config is not None
                return BatchedMarlinExperts(
                    max_num_tokens=max_num_tokens_per_rank,
                    num_dispatchers=prepare_finalize.num_dispatchers(),
                    quant_config=self.moe_quant_config,
                )
            raise NotImplementedError(
                f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                "EP batched experts format"
            )

        assert self.moe_quant_config is not None
        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            # B200 code-path
            kwargs = {
                "gemm1_alpha": layer.gemm1_alpha,
                "gemm1_beta": layer.gemm1_beta,
                "gemm1_clamp_limit": layer.gemm1_clamp_limit,
                # TODO(bnell): part of quant_config
                "max_capture_size": self.max_capture_size,
            }
            return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            return MarlinExperts(self.moe_quant_config)
        if self.mxfp4_backend == Mxfp4Backend.TRITON:
            return OAITritonExperts(self.moe_quant_config)
        raise NotImplementedError(
            f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
        )

    @property
    def allow_inplace(self) -> bool:
        """Whether in-place execution is allowed for this quant method.

        Always ``True`` for MXFP4.
        """
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MXFP4 mixture-of-experts forward pass.

        Dispatches on ``self.mxfp4_backend``:

        * ``MARLIN``: routes with ``layer.select_experts`` and calls
          ``fused_marlin_moe``.
        * ``SM100_FI_MXFP4_MXFP8_TRTLLM`` / ``SM100_FI_MXFP4_BF16``: passes
          ``router_logits`` directly to FlashInfer's
          ``trtllm_fp4_block_scale_moe`` (no ``select_experts`` call).
        * ``SM100_FI_MXFP4_MXFP8_CUTLASS`` / ``SM90_FI_MXFP4_BF16``: routes
          with ``layer.select_experts`` and calls
          ``flashinfer_cutlass_fused_moe`` into a bfloat16 output buffer.
        * ``TRITON``: calls ``triton_kernel_moe_forward``.

        Args:
            layer: The MoE layer holding the post-processed weights, scales,
                biases and (for FlashInfer paths) swiglu parameters.
            x: Input hidden states.
            router_logits: Router logits used for expert selection.
            top_k: Number of experts per token.
            renormalize: Whether top-k routing weights are renormalized.
            global_num_experts: Total number of experts across all ranks.
            expert_map: Optional global-to-local expert id map (Marlin and
                Triton paths).
            apply_router_weight_on_input: Forwarded to Marlin / Triton
                kernels; must be ``False`` for the other backends.
            activation: Activation name; forwarded to Marlin and validated
                by ``_can_support_mxfp4`` for the other backends.
            enable_eplb: Must be ``False``; EPLB is unsupported.
            The remaining routing arguments are only validated through
            ``_can_support_mxfp4`` for non-Marlin backends.

        Returns:
            The MoE output tensor.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
            AssertionError: If a non-Marlin backend gets a configuration that
                ``_can_support_mxfp4`` rejects, or a BF16 backend receives
                non-bfloat16 input.
            ValueError: If ``self.mxfp4_backend`` is not handled.
        """
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            topk_weights, topk_ids, _ = layer.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )

            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_bias,
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=None,
                global_scale2=None,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                activation=activation,
                expert_map=expert_map,
            )

        # Every non-Marlin backend only supports a restricted routing setup.
        assert _can_support_mxfp4(
            use_grouped_topk,
            topk_group,
            num_expert_group,
            expert_map,
            custom_routing_function,
            e_score_correction_bias,
            apply_router_weight_on_input,
            scoring_func,
            activation,
            expert_load_view,
            logical_to_physical_map,
            logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer import trtllm_fp4_block_scale_moe

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                # Activations are consumed as-is in bfloat16.
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                global_num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                self.num_experts,  # local num experts
                # NOTE(review): the next two positional Nones map to FlashInfer
                # parameters not named here; confirm against its signature.
                None,
                None,
                1 if renormalize else 0,  # routing_method_type, renormalize
                True,  # do finalize
                tune_max_num_tokens=max(self.max_capture_size, 1),
            )[0]
            return trtllm_gen_output
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

            topk_weights, topk_ids, _ = layer.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )

            # Backend-specific preparation
            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, True, 32)

                # Per-expert input scales of 1.0 (placeholders in quant_scales).
                fake_input_scale = torch.ones(self.num_experts, device=x.device)
                quant_scales = [
                    layer.w13_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                    layer.w2_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                ]

                fi_input = x_quant
                extra_kwargs = dict(
                    use_mxfp8_act_scaling=True,
                    input_sf=x_scale,
                    fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
                    fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
                )
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
                assert x.dtype == torch.bfloat16

                quant_scales = [
                    layer.w13_weight_scale,
                    layer.w2_weight_scale,
                ]

                fi_input = x
                extra_kwargs = dict(
                    use_w4_group_scaling=True,
                    fc1_expert_weights=layer.w13_weight,
                    fc2_expert_weights=layer.w2_weight,
                )

            # The kernel writes into this preallocated bfloat16 buffer.
            output = torch.empty_like(x, dtype=torch.bfloat16)
            _ = flashinfer_cutlass_fused_moe(
                input=fi_input,
                token_selected_experts=topk_ids.to(torch.int).contiguous(),
                token_final_scales=topk_weights,
                output_dtype=torch.bfloat16,
                output=output,
                quant_scales=quant_scales,
                fc1_expert_biases=layer.w13_bias,
                fc2_expert_biases=layer.w2_bias,
                swiglu_alpha=layer.gemm1_alpha,
                swiglu_beta=layer.gemm1_beta,
                swiglu_limit=layer.gemm1_clamp_limit,
                tp_size=self.moe.tp_size,
                tp_rank=self.moe.tp_rank,
                ep_size=self.moe.ep_size,
                ep_rank=self.moe.ep_rank,
                tune_max_num_tokens=max(self.max_capture_size, 1),
                **extra_kwargs,
            )

            return output
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")


class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 MoE method that delegates expert computation to IPEX.

    Parameters are created by the parent ``Mxfp4MoEMethod``. After loading,
    they are wrapped in ``ipex.llm.modules.GatedMLPMOE``, and ``apply`` calls
    that fused module.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        # Read later for the expert-parallel offset of this rank's experts.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create MXFP4 expert parameters via the parent implementation.

        Also records the unpadded ``hidden_size``. ``apply`` uses it to pad
        the input and to slice the padding back off the output.
        """
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reinterpret packed weights as int32 and build the IPEX fused MoE.

        Stores the resulting module on ``layer.ipex_fusion``.
        """
        import intel_extension_for_pytorch as ipex

        # Reinterpret the packed weight storage as int32 (a view, no copy);
        # presumably the layout GatedMLPMOE expects with is_mxfp4=True -- confirm.
        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
        # Global id of the first expert owned by this expert-parallel rank.
        ep_rank_start = self.moe_config.ep_rank * self.moe_config.num_local_experts
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            is_mxfp4=True,
            experts_start_id=ep_rank_start,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MoE forward pass through ``layer.ipex_fusion``.

        Steps:

        1. Zero-pad the last dimension of ``x`` up to a multiple of 128 of
           the original hidden size.
        2. Call the IPEX fused module with the routing arguments.
        3. Slice the result back to ``original_hidden_size``.

        Only ``use_grouped_topk``, ``top_k``, ``router_logits``,
        ``renormalize``, ``topk_group`` and ``num_expert_group`` are forwarded
        to IPEX. The other routing arguments are not used by this path.

        Raises:
            AssertionError: If ``activation`` is not ``"swigluoai"``.
            NotImplementedError: If ``enable_eplb`` is set. EPLB's load view
                and expert maps would otherwise be silently ignored.
        """
        assert activation == "swigluoai", (
            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
        )
        # Same policy as Mxfp4MoEMethod.apply: refuse EPLB instead of
        # silently ignoring expert_load_view / logical_to_physical_map.
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        hidden_size_pad = round_up(self.original_hidden_size, 128)
        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
        hidden_states = layer.ipex_fusion(
            x_pad,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            activation="swiglu_oai",
        )
        # Drop the padding columns added above.
        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
        return hidden_states