mxfp4.py 46 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from enum import Enum
4
5
6
7
8

import torch
from torch.nn.parameter import Parameter

from vllm import envs
9
from vllm.config import get_current_vllm_config
10
from vllm.logger import init_logger
11
from vllm.model_executor.layers.attention import Attention
12
13
14
15
16
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
17
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
18
from vllm.model_executor.layers.fused_moe.config import (
19
    FusedMoEQuantConfig,
20
    mxfp4_mxfp8_moe_quant_config,
21
    mxfp4_w4a16_moe_quant_config,
22
    ocp_mx_moe_quant_config,
23
)
24
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
25
    BatchedMarlinExperts,
26
27
28
    MarlinExperts,
    fused_marlin_moe,
)
29
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
30
    OAITritonExperts,
31
    UnfusedOAITritonExperts,
32
)
33
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
34
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
35
36
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
37
38
39
    QuantizationConfig,
    QuantizeMethodBase,
)
40
41
42
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
43
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
44
45
    prepare_moe_fp4_layer_for_marlin,
)
46
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
47
48
    _can_support_mxfp4,
    _swizzle_mxfp4,
49
    get_padding_alignment,
50
51
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
52
from vllm.model_executor.utils import set_weight_attrs
53
from vllm.platforms import current_platform
54
from vllm.scalar_type import scalar_types
55
from vllm.utils.flashinfer import has_flashinfer
56
from vllm.utils.import_utils import has_triton_kernels
Cyrus Leung's avatar
Cyrus Leung committed
57
from vllm.utils.math_utils import round_up
58

59
60
61
logger = init_logger(__name__)


62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Kernel backends that can execute an MXFP4-quantized MoE layer.
class Mxfp4Backend(Enum):
    """Backend choice for MXFP4 MoE layers.

    ``NONE`` signals that no compatible backend was found for the current
    platform/configuration. The remaining members name the kernel family,
    the target SM generation (where relevant) and the weight/activation
    formats.
    """

    NONE = 0

    # ---- FlashInfer kernels ----
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1  # SM100, MXFP4 x MXFP8, TRT-LLM kernels
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2  # SM100, MXFP4 x MXFP8, CUTLASS kernels
    SM100_FI_MXFP4_BF16 = 3  # SM100, MXFP4 x BF16
    SM90_FI_MXFP4_BF16 = 4  # SM90, MXFP4 x BF16

    # ---- Marlin kernels ----
    MARLIN = 5

    # ---- triton_kernels ----
    TRITON = 6


79
80
81
82
83
84
85
86
def _triton_kernels_supported() -> bool:
    """Return True if the ``triton_kernels`` MoE path is usable on this device.

    Requires the ``triton_kernels`` package and a device capability in the
    half-open range [(9, 0), (11, 0)).
    """
    return (
        has_triton_kernels()
        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
    )


def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    """
    Not all MXFP4 backends support LoRA. Select backends that are known to
    have LoRA support.

    Returns:
        ``Mxfp4Backend.NONE`` on non-CUDA platforms; otherwise Triton or
        Marlin (FlashInfer backends are never chosen on this path).
    """
    if not current_platform.is_cuda():
        return Mxfp4Backend.NONE

    # Only Triton and Marlin are considered for LoRA. Triton is chosen only
    # when VLLM_MXFP4_USE_MARLIN is explicitly ``False`` (identity check) and
    # the device supports triton_kernels; any other value selects Marlin.
    triton_kernels_supported = _triton_kernels_supported()
    if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
        logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
        return Mxfp4Backend.TRITON

    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
    return Mxfp4Backend.MARLIN


def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    """Select the MXFP4 MoE backend for the current platform.

    Args:
        with_lora_support: If True, delegate to
            ``get_mxfp4_backend_with_lora`` so that only LoRA-capable
            backends are returned.

    Returns:
        The selected backend, or ``Mxfp4Backend.NONE`` when no branch
        applies (e.g. ROCm without ``triton_kernels``, or other platforms).
    """
    if with_lora_support:
        return get_mxfp4_backend_with_lora()

    if current_platform.is_cuda():
        # FlashInfer backends, checked in priority order. The first three are
        # opt-in via env vars; SM100 BF16 is the FlashInfer default on SM100.
        if (
            current_platform.is_device_capability(90)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        ):
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16
        elif (
            current_platform.is_device_capability_family(100)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
        ):
            logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
        elif (
            current_platform.is_device_capability_family(100)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        ):
            # Log the choice like every other selected backend does.
            logger.info_once("Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
        elif current_platform.is_device_capability_family(100) and has_flashinfer():
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy."
            )
            return Mxfp4Backend.SM100_FI_MXFP4_BF16
        elif (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        ) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results."
            )

        # No FlashInfer backend selected: fall back to Marlin or Triton.
        triton_kernels_supported = _triton_kernels_supported()
        if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
            logger.info_once("Using Marlin backend")
            return Mxfp4Backend.MARLIN
        else:
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON
    elif current_platform.is_xpu():
        logger.info_once("Using ipex marlin backend on XPU")
        return Mxfp4Backend.MARLIN
    elif current_platform.is_rocm() and has_triton_kernels():
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    return Mxfp4Backend.NONE
170
171
172


class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 checkpoints.

    Only ``FusedMoE`` layers receive an MXFP4 quantize method. Linear layers
    fall back to ``UnquantizedLinearMethod``; attention and all other layers
    get no quantize method (``None``).
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Optional layer prefixes to keep unquantized. Consulted only for
        # linear layers in ``get_quant_method``.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        """Build the config.

        NOTE(review): ``config`` is not read, so ``ignored_layers`` is always
        ``None`` when constructed through this path.
        """
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum supported device capability (80)."""
        min_capability = 80
        return min_capability

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are needed for this method."""
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method to attach to ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Module-name prefix of ``layer``; used for ignore-list
                matching and for the Marlin input-dtype lookup.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers, an MXFP4 MoE
            method for ``FusedMoE`` layers, and ``None`` for everything else
            (attention included).
        """
        if isinstance(layer, LinearBase):
            skipped_by_config = self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if not skipped_by_config:
                # TODO: Add support for MXFP4 Linear Method.
                # MXFP4 LinearMethod is available in AMD-Quark, refer to that
                # implementation if you are interested in enabling MXFP4 here.
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            if current_platform.is_xpu():
                return IpexMxfp4MoEMethod(layer.moe_config)
            moe_method = Mxfp4MoEMethod(layer.moe_config)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method

        if isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )

        return None


class Mxfp4MoEMethod(FusedMoEMethodBase):
    def __init__(self, moe: FusedMoEConfig):
        """Select the MXFP4 backend for ``moe`` and set up per-method state.

        Raises:
            AssertionError: if no compatible MXFP4 MoE backend is available
                for the current platform/configuration.
        """
        super().__init__(moe)
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)

        # Set per layer by Mxfp4Config.get_quant_method after construction.
        self.marlin_input_dtype = None
        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # The adjacent string literals are concatenated, so each piece must
        # end with a space to keep the words separated in the message.
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
            "Please check your environment and try again."
        )
        # Permutation indices keyed by tensor shape, reused across experts
        # when shuffling weights for the FlashInfer TRT-LLM backends.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
249

250
251
252
253
254
255
256
257
258
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate padded MXFP4 MoE weights, block scales and biases on ``layer``.

        ``hidden_size`` and the per-partition intermediate size are first
        rounded up to a backend-specific alignment. The padded sizes are kept
        on ``self`` (``hidden_size`` / ``intermediate_size``) for later shape
        checks. Weights and scales are raw ``uint8`` tensors: weights store
        ``// 2`` bytes along the packed axis, and scales store one entry per
        32-element block.
        """
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8
        # FIXME (zyongye): ship after torch and safetensors support mxfp4, i.e.
        # switch to torch.float4_e2m1fn_x2 weights / torch.float8_e8m0fnu
        # scales once both attributes exist on ``torch``.

        mxfp4_block = 32
        backend = self.mxfp4_backend
        padded_intermediate = intermediate_size_per_partition

        if backend == Mxfp4Backend.MARLIN:
            # The moe marlin kernel requires n % 256 == 0 and k % 128 == 0 for
            # each linear:
            #   gate_up_proj: n = 2 * padded_intermediate, k = hidden_size
            #   down_proj:    n = hidden_size,             k = padded_intermediate
            padded_intermediate = round_up(intermediate_size_per_partition, 128)
            hidden_alignment = 128 if current_platform.is_xpu() else 256
            hidden_size = round_up(hidden_size, hidden_alignment)

            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = padded_intermediate
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # Pad the intermediate size to a multiple of 2 * mxfp4_block so it
            # can hold non-uniformly sharded tensors and swizzling; the extra
            # padding is for performance.
            padded_intermediate = round_up(intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            padded_intermediate = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 128)
        elif current_platform.is_rocm():
            alignment = get_padding_alignment()
            padded_intermediate = round_up(intermediate_size_per_partition, alignment)
            hidden_size = round_up(hidden_size, alignment)
        else:
            padded_intermediate = round_up(intermediate_size_per_partition, 64)

        self.intermediate_size = padded_intermediate
        self.hidden_size = hidden_size

        # (parameter name, shape, dtype), registered in this order.
        param_specs = (
            # Fused gate_up_proj (column parallel)
            (
                "w13_weight",
                (num_experts, 2 * padded_intermediate, hidden_size // 2),
                weight_dtype,
            ),
            (
                "w13_weight_scale",
                (num_experts, 2 * padded_intermediate, hidden_size // mxfp4_block),
                scale_dtype,
            ),
            ("w13_bias", (num_experts, 2 * padded_intermediate), torch.bfloat16),
            # down_proj (row parallel)
            (
                "w2_weight",
                (num_experts, hidden_size, padded_intermediate // 2),
                weight_dtype,
            ),
            (
                "w2_weight_scale",
                (num_experts, hidden_size, padded_intermediate // mxfp4_block),
                scale_dtype,
            ),
            ("w2_bias", (num_experts, hidden_size), torch.bfloat16),
        )
        for param_name, shape, dtype in param_specs:
            param = torch.nn.Parameter(
                torch.zeros(shape, dtype=dtype), requires_grad=False
            )
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
402
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
403
            prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
404
405
406
407
408
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
409
            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
410
411
412
413
414
415
416
417
418
419
420
421
422

            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
423
424
            sf_block_size = 32  # mxfp4 block size

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
459
460
461
462
463
464
465
466

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

co63oc's avatar
co63oc committed
467
            # Swap w1 and w3 as the definition of
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
500
                # w13 weight shuffling
501
                permute_indices = get_w2_permute_indices_with_cache(
502
503
504
505
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
506
507
508
509
510
                gemm1_weights_mxfp4_shuffled.append(
                    w13_weight[i]
                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                    .contiguous()
                )
511
                # w13 scale shuffling
512
                permute_sf_indices = get_w2_permute_indices_with_cache(
513
514
515
516
517
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
518
                gemm1_scales_mxfp4_shuffled.append(
519
520
521
522
523
524
525
526
                    nvfp4_block_scale_interleave(
                        w13_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w13_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
527
                # w13 bias shuffling
528
                permute_bias_indices = get_w2_permute_indices_with_cache(
529
530
531
532
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
533
534
535
536
537
538
                gemm1_bias_shuffled.append(
                    w13_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                    .contiguous()
                )
539
                # w2 weight shuffling
540
                permute_indices = get_w2_permute_indices_with_cache(
541
542
543
544
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
545
546
547
548
549
                gemm2_weights_mxfp4_shuffled.append(
                    w2_weight[i]
                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                    .contiguous()
                )
550
                # w2 scale shuffling
551
                permute_sf_indices = get_w2_permute_indices_with_cache(
552
553
554
555
556
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
557
                gemm2_scales_mxfp4_shuffled.append(
558
559
560
561
562
563
564
565
                    nvfp4_block_scale_interleave(
                        w2_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w2_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
566
                # w2 bias shuffling
567
                permute_indices = get_w2_permute_indices_with_cache(
568
569
570
571
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
572
573
574
575
576
577
                gemm2_bias_shuffled.append(
                    w2_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                    .contiguous()
                )
578
579

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
580
581
582
583
584
585
586
587
588
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
589
590

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
591
592
593
594
595
596
597
598
599
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
600
601

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
602
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
603
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
604
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
605
606
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
                requires_grad=False,
            )
            layer.w2_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
629
630
631
632

            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
692
693
                    w13_scale_swapped.view(torch.uint8)
                ).reshape(orig_shape)
694
695
696
697

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
698
699
700
701
702
703
704
705
706
707
708
                    w2_s.view(torch.uint8)
                ).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
                layer.w13_weight_scale = Parameter(
                    w13_scale_interleaved, requires_grad=False
                )
                layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
                layer.w2_weight_scale = Parameter(
                    w2_scale_interleaved, requires_grad=False
                )
709
710
711
712
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
713
714
715
                    w_interleaved = w.reshape(
                        w_shape[0], w_shape[1], (w_shape[2] // 4), 4
                    )
716
717
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
718
719
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                    )
720
721
                    return w_interleaved

722
723
                w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
724
725
726

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
727
728
729
730
731
732
733
734
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)

                layer.w13_weight = torch.nn.Parameter(
                    torch.cat([w3_w, w1_w], dim=1), requires_grad=False
                )
                layer.w13_bias = torch.nn.Parameter(
                    w13_bias_swapped, requires_grad=False
                )
735
                layer.w13_weight_scale = torch.nn.Parameter(
736
737
                    w31_scales_interleaved, requires_grad=False
                )
738
                layer.w2_weight_scale = torch.nn.Parameter(
739
740
                    w2_scales_interleaved, requires_grad=False
                )
741
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
742
743
744
745
746
747
748
749
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)

750
751
752
753
754
            # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
            # (stored in self.fused_experts) to determine if the MoE has a
            # batched activation format. As self.fused_experts is not
            # initialized at this point, we resort to checking the MoE config
            # directly.
755
            is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
756
            if is_batched_moe:
757
758
759
760
761
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8

            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
762
763
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
764
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
765
766
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )
767
768

            self.w13_precision_config = PrecisionConfig(
769
770
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
771
            self.w2_precision_config = PrecisionConfig(
772
773
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )
774

775
776
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
777
778
779
780
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
781
        else:
782
783
784
785
            raise ValueError(
                f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
                f"should be one of: {list(Mxfp4Backend)}."
            )
786

787
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the fused-MoE quantization config for the active backend.

        All backends forward the layer's ``w13_bias``/``w2_bias``. They differ
        in which config constructor is used and where the weight scales come
        from.

        Args:
            layer: The MoE layer whose biases and weight scales are placed
                into the returned config.

        Returns:
            The quantization config matching ``self.mxfp4_backend``.
        """
        backend = self.mxfp4_backend
        biases = dict(w1_bias=layer.w13_bias, w2_bias=layer.w2_bias)

        if backend == Mxfp4Backend.TRITON:
            # Triton keeps its (swizzled) scales inside the PrecisionConfig
            # objects built during weight post-processing, not on the layer.
            return mxfp4_w4a16_moe_quant_config(
                w1_scale=self.w13_precision_config,
                w2_scale=self.w2_precision_config,
                **biases,
            )

        scales = dict(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        ):
            # MXFP4 weights combined with MXFP8 activations.
            return mxfp4_mxfp8_moe_quant_config(**biases, **scales)

        if backend in (Mxfp4Backend.MARLIN, Mxfp4Backend.SM100_FI_MXFP4_BF16):
            # MXFP4 weights with unquantized (16-bit) activations.
            return mxfp4_w4a16_moe_quant_config(**biases, **scales)

        # Every other backend (e.g. SM90_FI_MXFP4_BF16) uses the generic
        # OCP MX config with an "mxfp4" quant dtype.
        return ocp_mx_moe_quant_config(quant_dtype="mxfp4", **biases, **scales)
833

834
835
836
837
838
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation for the modular MoE kernel.

        Args:
            prepare_finalize: The prepare/finalize stage. Its
                ``activation_format`` selects between the batched-experts
                and the standard experts implementations.
            layer: The MoE layer. Only the TRT-LLM path reads from it (the
                ``gemm1_*`` parameters).

        Returns:
            An experts object matching ``self.mxfp4_backend``.

        Raises:
            NotImplementedError: If the backend has no implementation for the
                requested activation format.
        """
        backend = self.mxfp4_backend
        wants_batched = (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        )

        if wants_batched:
            # Marlin is the only backend with a batched-experts variant.
            if backend != Mxfp4Backend.MARLIN:
                raise NotImplementedError(
                    f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                    "EP batched experts format"
                )
            max_tokens = prepare_finalize.max_num_tokens_per_rank()
            assert max_tokens is not None
            assert self.moe_quant_config is not None
            return BatchedMarlinExperts(
                max_num_tokens=max_tokens,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                moe_config=self.moe,
            )

        quant_config = self.moe_quant_config
        assert quant_config is not None

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # B200 code-path
            return TrtLlmGenExperts(
                self.moe,
                quant_config,
                gemm1_alpha=layer.gemm1_alpha,
                gemm1_beta=layer.gemm1_beta,
                gemm1_clamp_limit=layer.gemm1_clamp_limit,
                # TODO(bnell): part of quant_config
                max_capture_size=self.max_capture_size,
            )

        if backend == Mxfp4Backend.MARLIN:
            return MarlinExperts(self.moe, quant_config)

        if backend == Mxfp4Backend.TRITON:
            # With LoRA enabled, the unfused Triton experts class is used.
            experts_cls = (
                UnfusedOAITritonExperts
                if self.moe.is_lora_enabled
                else OAITritonExperts
            )
            return experts_cls(self.moe, quant_config)

        raise NotImplementedError(
            f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
        )
883

884
885
886
    @property
    def allow_inplace(self) -> bool:
        """Report that in-place execution is permitted.

        The answer is unconditionally ``True``, whatever the backend.
        """
        inplace_ok = True
        return inplace_ok
887

888
889
890
891
892
893
894
895
    @property
    def is_monolithic(self) -> bool:
        """Whether routing and expert compute happen in one fused call.

        Monolithic backends go through ``apply_monolithic``, which receives
        raw ``router_logits``. All other backends go through ``apply``, which
        receives precomputed ``topk_weights``/``topk_ids``.
        """
        monolithic_backends = (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
            Mxfp4Backend.TRITON,
        )
        return self.mxfp4_backend in monolithic_backends

896
897
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the non-monolithic MXFP4 MoE forward with precomputed routing.

        Handles the Marlin backend and the two FlashInfer CUTLASS backends:
        ``SM100_FI_MXFP4_MXFP8_CUTLASS`` (activations quantized to MXFP8) and
        ``SM90_FI_MXFP4_BF16`` (bf16 activations).

        Args:
            layer: The MoE layer holding the post-processed weights, scales,
                biases and routing/SwiGLU parameters.
            x: Input hidden states.
            topk_weights: Router weights of the selected experts.
            topk_ids: Indices of the selected experts.

        Returns:
            The MoE output. The FlashInfer CUTLASS paths return bf16.

        Raises:
            NotImplementedError: If EPLB is enabled on the layer.
            ValueError: If the backend is not handled by this method.
        """
        assert not self.is_monolithic
        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_bias,
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                topk_weights,
                topk_ids,
                global_scale1=None,
                global_scale2=None,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                activation=layer.activation,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
            )

        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        # Use an explicit raise rather than an assert. Under ``python -O`` an
        # assert is stripped, and an unsupported backend would then reach the
        # kernel call with ``quant_scales``/``fi_input``/``extra_kwargs``
        # unbound. The message matches ``apply_monolithic``.
        if self.mxfp4_backend not in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")

        from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

        # Backend-specific preparation
        if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
            from flashinfer import mxfp8_quantize

            # Quantize activations to MXFP8. The resulting scale factors are
            # passed to the kernel via ``input_sf`` below.
            x_quant, x_scale = mxfp8_quantize(x, True, 32)

            # Placeholder per-expert input scales of 1.0. The real activation
            # scales travel separately through ``input_sf``.
            fake_input_scale = torch.ones(self.num_experts, device=x.device)
            quant_scales = [
                layer.w13_weight_scale.contiguous().view(torch.int32),
                fake_input_scale,
                layer.w2_weight_scale.contiguous().view(torch.int32),
                fake_input_scale,
            ]

            fi_input = x_quant
            extra_kwargs = dict(
                use_mxfp8_act_scaling=True,
                input_sf=x_scale,
                fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
                fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
            )
        else:  # Mxfp4Backend.SM90_FI_MXFP4_BF16
            # bf16 activations are fed directly. Weights use group scaling.
            assert x.dtype == torch.bfloat16

            quant_scales = [
                layer.w13_weight_scale,
                layer.w2_weight_scale,
            ]

            fi_input = x
            extra_kwargs = dict(
                use_w4_group_scaling=True,
                fc1_expert_weights=layer.w13_weight,
                fc2_expert_weights=layer.w2_weight,
            )

        # The kernel writes into this preallocated bf16 buffer.
        output = torch.empty_like(x, dtype=torch.bfloat16)

        flashinfer_cutlass_fused_moe(
            input=fi_input,
            token_selected_experts=topk_ids.to(torch.int).contiguous(),
            token_final_scales=topk_weights,
            output_dtype=torch.bfloat16,
            output=output,
            quant_scales=quant_scales,
            fc1_expert_biases=layer.w13_bias,
            fc2_expert_biases=layer.w2_bias,
            swiglu_alpha=layer.gemm1_alpha,
            swiglu_beta=layer.gemm1_beta,
            swiglu_limit=layer.gemm1_clamp_limit,
            tp_size=self.moe.tp_size,
            tp_rank=self.moe.tp_rank,
            ep_size=self.moe.ep_size,
            ep_rank=self.moe.ep_rank,
            tune_max_num_tokens=max(self.max_capture_size, 1),
            **extra_kwargs,
        )

        return output

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the experts in a single fused kernel call.

        Serves the backends for which ``is_monolithic`` is true: the Triton
        kernels and the FlashInfer TRT-LLM kernels
        (``SM100_FI_MXFP4_MXFP8_TRTLLM`` and ``SM100_FI_MXFP4_BF16``).
        Routing is computed inside the kernel from ``router_logits``.

        Args:
            layer: The MoE layer holding weights, scales, biases and routing
                settings (``top_k``, ``renormalize``, ``expert_map``, ...).
            x: Input hidden states.
            router_logits: Raw router logits for ``x``.

        Returns:
            The MoE output tensor.

        Raises:
            NotImplementedError: If EPLB is enabled on the layer.
            ValueError: If the backend is not handled here.
        """
        assert self.is_monolithic

        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        backend = self.mxfp4_backend

        if backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=layer.top_k,
                renormalize=layer.renormalize,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        if backend not in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")

        from flashinfer import trtllm_fp4_block_scale_moe

        if backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
            # bf16 activations go in unquantized, with no scale tensor.
            assert x.dtype == torch.bfloat16
            hidden_states, hidden_states_scale = x, None
        else:
            from flashinfer import mxfp8_quantize

            # Quantize activations to MXFP8, then reinterpret the scales and
            # reshape them so their leading dims match ``x``.
            hidden_states, raw_scale = mxfp8_quantize(x, False)
            hidden_states_scale = raw_scale.view(torch.float8_e4m3fn).reshape(
                *x.shape[:-1], -1
            )

        kernel_outputs = trtllm_fp4_block_scale_moe(
            routing_logits=router_logits.to(torch.bfloat16),
            routing_bias=None,
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            gemm1_weights=layer.w13_weight,  # uint8 (e2m1 x 2)
            gemm1_weights_scale=layer.w13_weight_scale,
            gemm1_bias=layer.w13_bias,  # fp32 per expert per channel
            gemm1_alpha=layer.gemm1_alpha,  # fp32 per expert
            gemm1_beta=layer.gemm1_beta,  # fp32 per expert
            gemm1_clamp_limit=layer.gemm1_clamp_limit,  # fp32 per expert
            gemm2_weights=layer.w2_weight,  # uint8 (e2m1 x 2)
            gemm2_weights_scale=layer.w2_weight_scale,  # ue8m0
            gemm2_bias=layer.w2_bias,  # fp32 per expert per channel
            output1_scale_scalar=None,
            output1_scale_gate_scalar=None,
            output2_scale_scalar=None,
            num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            n_group=None,
            topk_group=None,
            intermediate_size=self.intermediate_size,  # padded to multiple of 256
            local_expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=self.num_experts,
            routed_scaling_factor=None,
            # Routing method id: 1 when renormalizing, else 0.
            routing_method_type=1 if layer.renormalize else 0,
            do_finalize=True,
            tune_max_num_tokens=max(self.max_capture_size, 1),
        )
        # Only the first element of the kernel's result is returned.
        return kernel_outputs[0]
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130


class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 MoE method that delegates expert execution to IPEX.

    Weight creation is inherited from ``Mxfp4MoEMethod``. After loading, the
    weights are wrapped in an ``ipex.llm.modules.GatedMLPMOE`` module stored
    as ``layer.ipex_fusion``. ``apply_monolithic`` calls that module with
    the raw router logits.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        # Kept to compute the first local expert id after weight loading.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights via the parent and record the unpadded hidden size."""
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        # Used at forward time to pad the input and strip the padding again.
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reinterpret the weights as int32 and build the IPEX fused MoE."""
        import intel_extension_for_pytorch as ipex

        # ``.view`` reinterprets the existing storage as int32 (no copy).
        for weight_name in ("w13_weight", "w2_weight"):
            weight = getattr(layer, weight_name)
            weight.data = weight.data.view(torch.int32)

        # Global id of this EP rank's first local expert.
        first_local_expert = (
            self.moe_config.ep_rank * self.moe_config.num_local_experts
        )
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            is_mxfp4=True,
            experts_start_id=first_local_expert,
        )

    @property
    def is_monolithic(self) -> bool:
        """IPEX always runs routing and experts in a single call."""
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Pad ``x``, run the IPEX fused MoE, then drop the padding.

        Raises:
            AssertionError: If ``layer.activation`` is not ``"swigluoai"``.
        """
        assert layer.activation == "swigluoai", (
            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
        )
        padded_width = round_up(self.original_hidden_size, 128)
        # Right-pad the last dim (zeros, F.pad's default) to ``padded_width``.
        x_padded = torch.nn.functional.pad(x, (0, padded_width - x.size(-1)))
        fused_out = layer.ipex_fusion(
            x_padded,
            layer.use_grouped_topk,
            layer.top_k,
            router_logits,
            layer.renormalize,
            layer.topk_group,
            layer.num_expert_group,
            activation="swiglu_oai",
        )
        # Keep only the original hidden columns, as a contiguous tensor.
        return fused_out[..., : self.original_hidden_size].contiguous()