mxfp4.py 47.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from enum import Enum
4
5
6
7
8

import torch
from torch.nn.parameter import Parameter

from vllm import envs
9
from vllm.config import get_current_vllm_config
10
from vllm.logger import init_logger
11
from vllm.model_executor.layers.attention import Attention
12
13
14
15
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
16
    MoEActivation,
17
)
18
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
19
20
21
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
22
from vllm.model_executor.layers.fused_moe.config import (
23
    FusedMoEQuantConfig,
24
    mxfp4_mxfp8_moe_quant_config,
25
    mxfp4_w4a16_moe_quant_config,
26
    ocp_mx_moe_quant_config,
27
)
28
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
29
    BatchedMarlinExperts,
30
31
    MarlinExperts,
)
32
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
33
    OAITritonExperts,
34
    UnfusedOAITritonExperts,
35
)
36
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
37
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
38
39
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
40
41
42
    QuantizationConfig,
    QuantizeMethodBase,
)
43
44
45
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
46
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
47
48
    prepare_moe_fp4_layer_for_marlin,
)
49
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
50
51
    _can_support_mxfp4,
    _swizzle_mxfp4,
52
    get_padding_alignment,
53
54
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
55
from vllm.model_executor.utils import set_weight_attrs
56
from vllm.platforms import current_platform
57
from vllm.utils.flashinfer import has_flashinfer
58
from vllm.utils.import_utils import has_triton_kernels
Cyrus Leung's avatar
Cyrus Leung committed
59
from vllm.utils.math_utils import round_up
60

61
62
63
logger = init_logger(__name__)


64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class Mxfp4Backend(Enum):
    """Kernel backends that can execute an MXFP4 fused-MoE layer.

    ``NONE`` signals that no compatible backend was found; every other member
    names a concrete kernel provider/configuration.
    """

    NONE = 0

    # --- FlashInfer kernels (SM100 family and SM90) ---
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4

    # --- Marlin kernels ---
    MARLIN = 5

    # --- Triton kernels ---
    TRITON = 6


81
82
83
84
85
86
87
88
def _triton_kernels_supported() -> bool:
    """Return whether the triton_kernels MXFP4 MoE path can run on this device.

    Requires the ``triton_kernels`` package and a device capability in the
    half-open range [(9, 0), (11, 0)). Shared by both backend selectors so the
    supported capability window is defined in exactly one place.
    """
    return has_triton_kernels() and (
        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
        (9, 0) <= current_platform.get_device_capability() < (11, 0)
    )


def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    """
    Not all MXFP4 backends support LoRA. Select backends that are known to
    have LoRA support.

    Only CUDA is handled; other platforms get ``Mxfp4Backend.NONE``. On CUDA,
    FlashInfer backends are never considered: Triton is returned when
    ``VLLM_MXFP4_USE_MARLIN`` is exactly ``False`` and triton_kernels are
    supported, otherwise Marlin.
    """
    if not current_platform.is_cuda():
        return Mxfp4Backend.NONE

    # Only the LoRA-capable Triton and Marlin kernels are candidates here.
    triton_kernels_supported = _triton_kernels_supported()
    # `is False` (rather than a falsy check) is deliberate: any value other
    # than an explicit False keeps Marlin as the LoRA default.
    # NOTE(review): presumably the env value is None when unset -- confirm in
    # vllm.envs.
    if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
        logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
        return Mxfp4Backend.TRITON

    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
    return Mxfp4Backend.MARLIN


def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    """Select the MXFP4 MoE kernel backend for the current platform.

    Args:
        with_lora_support: When True, delegate to
            ``get_mxfp4_backend_with_lora`` so only LoRA-capable backends are
            considered.

    Returns:
        On CUDA, a FlashInfer backend when the device family, FlashInfer
        availability and the matching env flag allow it, else Marlin or
        Triton. On XPU, ``MARLIN``. On ROCm with triton_kernels, ``TRITON``.
        Otherwise ``Mxfp4Backend.NONE``.
    """
    # Backend Selection
    if with_lora_support:
        return get_mxfp4_backend_with_lora()

    if current_platform.is_cuda():
        if (
            current_platform.is_device_capability(90)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        ):
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16
        elif (
            current_platform.is_device_capability_family(100)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
        ):
            logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
        elif (
            current_platform.is_device_capability_family(100)
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
        ):
            logger.info_once(
                "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100", scope="local"
            )
            return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
        elif current_platform.is_device_capability_family(100) and has_flashinfer():
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy."
            )
            return Mxfp4Backend.SM100_FI_MXFP4_BF16
        elif (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        ) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results."
            )

        # No FlashInfer backend was selected: fall back to Marlin or Triton.
        triton_kernels_supported = _triton_kernels_supported()
        if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
            logger.info_once("Using Marlin backend")
            return Mxfp4Backend.MARLIN
        else:
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON
    elif current_platform.is_xpu():
        logger.info_once("Using xpu backend on XPU")
        return Mxfp4Backend.MARLIN
    elif current_platform.is_rocm() and has_triton_kernels():
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    return Mxfp4Backend.NONE
175
176
177


class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 (``"mxfp4"``) checkpoints.

    Only fused-MoE layers receive an MXFP4 method. Linear layers always fall
    back to ``UnquantizedLinearMethod``, and attention layers get no method.
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Layer prefixes checked by is_layer_skipped for linear layers.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # NOTE: `config` is not inspected; the instance has ignored_layers=None.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum supported device capability.
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # No extra config files are needed for this method.
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantization method for ``layer``; None when unhandled."""
        if isinstance(layer, LinearBase):
            explicitly_ignored = bool(self.ignored_layers) and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if not explicitly_ignored:
                # TODO: Add support for MXFP4 Linear Method.
                # MXFP4 LinearMethod is available in AMD-Quark, refer to that
                # implementation if you are interested in enabling MXFP4 here.
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            moe_method_cls = (
                XpuMxfp4MoEMethod if current_platform.is_xpu() else Mxfp4MoEMethod
            )
            return moe_method_cls(layer.moe_config)

        if isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )

        return None

    def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
        """MXFP4 config always uses MXFP4 quantization."""
        return True

240
241

class Mxfp4MoEMethod(FusedMoEMethodBase):
242
243
    """MXFP4 MoE quantization method."""

244
    def __init__(self, moe: FusedMoEConfig):
        """Select the MXFP4 kernel backend for this MoE layer.

        Args:
            moe: Fused-MoE configuration; ``moe.is_lora_enabled`` restricts
                backend selection to LoRA-capable kernels.

        Raises:
            AssertionError: if no compatible MXFP4 MoE backend is available.
        """
        super().__init__(moe)
        self.weight_dtype = "mxfp4"
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)

        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # Adjacent string literals are concatenated verbatim, so each fragment
        # must carry its own separating space.
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
            "Please check your environment and try again."
        )
        # Permutation indices reused by get_w2_permute_indices_with_cache
        # while shuffling weights after loading.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        # Modular kernel, built after weight loading (Marlin path).
        self.moe_mk: mk.FusedMoEModularKernel | None = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate zero-filled MXFP4 MoE parameters on ``layer``.

        The hidden and intermediate sizes are first padded to the alignment
        required by ``self.mxfp4_backend``. The padded sizes are stored in
        ``self.hidden_size`` and ``self.intermediate_size``. On the Marlin
        backend they are also written onto ``layer``, together with
        ``params_dtype`` and ``num_experts``.

        Six parameters are registered, in this order:
        ``w13_weight``, ``w13_weight_scale``, ``w13_bias``,
        ``w2_weight``, ``w2_weight_scale`` and ``w2_bias``.
        Weights and scales are ``torch.uint8``. The packed weight's last dim
        is half the logical size, and the scale's last dim is the logical
        size divided by the 32-element block. Biases are ``torch.bfloat16``.
        Every parameter receives ``extra_weight_attrs`` via
        ``set_weight_attrs``.
        """
        self.num_experts = num_experts
        # FIXME (zyongye): ship after torch and safetensors support mxfp4; then
        # use torch.float4_e2m1fn_x2 / torch.float8_e8m0fnu instead of uint8.
        packed_dtype = torch.uint8
        scale_dtype = torch.uint8
        block = 32  # MXFP4 block size: number of elements sharing one scale

        backend = self.mxfp4_backend
        padded_inter = intermediate_size_per_partition
        if backend == Mxfp4Backend.MARLIN:
            # The moe marlin kernel requires n % 256 == 0 and k % 128 == 0 for
            # each linear. gate_up_proj: n = 2 * padded_inter, k = hidden_size.
            # down_proj: n = hidden_size, k = padded_inter.
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(
                hidden_size, 128 if current_platform.is_xpu() else 256
            )
            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = padded_inter
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # Pad to a multiple of 2 * mxfp4 block (non-uniform sharding and
            # swizzling), plus extra padding for performance.
            padded_inter = round_up(intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 128)
        elif current_platform.is_rocm():
            alignment = get_padding_alignment()
            padded_inter = round_up(intermediate_size_per_partition, alignment)
            hidden_size = round_up(hidden_size, alignment)
        else:
            # Remaining backends only pad the intermediate dimension.
            padded_inter = round_up(intermediate_size_per_partition, 64)

        self.intermediate_size = padded_inter
        self.hidden_size = hidden_size

        # (name, shape, dtype). w13 is the fused gate_up_proj (column
        # parallel); w2 is down_proj (row parallel).
        n = num_experts
        w13_rows = 2 * padded_inter
        param_specs = [
            ("w13_weight", (n, w13_rows, hidden_size // 2), packed_dtype),
            ("w13_weight_scale", (n, w13_rows, hidden_size // block), scale_dtype),
            ("w13_bias", (n, w13_rows), torch.bfloat16),
            ("w2_weight", (n, hidden_size, padded_inter // 2), packed_dtype),
            ("w2_weight_scale", (n, hidden_size, padded_inter // block), scale_dtype),
            ("w2_bias", (n, hidden_size), torch.bfloat16),
        ]
        for param_name, shape, dtype in param_specs:
            param = torch.nn.Parameter(
                torch.zeros(*shape, dtype=dtype), requires_grad=False
            )
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
413
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
            prepare_moe_fp4_layer_for_marlin(
                layer, input_dtype=get_marlin_input_dtype()
            )

            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
            assert self.moe_quant_config is not None

            prepare_finalize = maybe_make_prepare_finalize(
                moe=self.moe,
                quant_config=self.moe_quant_config,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                allow_new_interface=True,
            )
            assert prepare_finalize is not None

            self.moe_mk = mk.FusedMoEModularKernel(
                prepare_finalize,
                MarlinExperts(
                    self.moe,
                    self.moe_quant_config,
                ),
                inplace=not self.moe.disable_inplace,
                shared_experts=None,
            )
438
439
440
441
442
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
443
            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
444
445
446
447
448
449
450
451
452
453
454
455
456

            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
457
458
            sf_block_size = 32  # mxfp4 block size

459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
493
494
495
496
497
498
499
500

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

co63oc's avatar
co63oc committed
501
            # Swap w1 and w3 as the definition of
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                """Swap entries 2k and 2k + 1 along ``axis`` of tensor ``x``.

                A negative ``axis`` counts from the end. The length of ``axis``
                is split into (pairs, 2), so it is expected to be even.
                """
                dims = list(x.shape)
                if axis < 0:
                    axis += len(dims)

                # Give each pair its own trailing axis of size 2.
                paired = list(dims)
                paired[axis] //= 2
                paired.insert(axis + 1, 2)

                # Reverse every pair, then collapse back to the input shape.
                return x.reshape(paired).flip(axis + 1).reshape(dims)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
534
                # w13 weight shuffling
535
                permute_indices = get_w2_permute_indices_with_cache(
536
537
538
539
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
540
541
542
543
544
                gemm1_weights_mxfp4_shuffled.append(
                    w13_weight[i]
                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                    .contiguous()
                )
545
                # w13 scale shuffling
546
                permute_sf_indices = get_w2_permute_indices_with_cache(
547
548
549
550
551
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
552
                gemm1_scales_mxfp4_shuffled.append(
553
554
555
556
557
558
559
560
                    nvfp4_block_scale_interleave(
                        w13_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w13_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
561
                # w13 bias shuffling
562
                permute_bias_indices = get_w2_permute_indices_with_cache(
563
564
565
566
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
567
568
569
570
571
572
                gemm1_bias_shuffled.append(
                    w13_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                    .contiguous()
                )
573
                # w2 weight shuffling
574
                permute_indices = get_w2_permute_indices_with_cache(
575
576
577
578
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
579
580
581
582
583
                gemm2_weights_mxfp4_shuffled.append(
                    w2_weight[i]
                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                    .contiguous()
                )
584
                # w2 scale shuffling
585
                permute_sf_indices = get_w2_permute_indices_with_cache(
586
587
588
589
590
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
591
                gemm2_scales_mxfp4_shuffled.append(
592
593
594
595
596
597
598
599
                    nvfp4_block_scale_interleave(
                        w2_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w2_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
600
                # w2 bias shuffling
601
                permute_indices = get_w2_permute_indices_with_cache(
602
603
604
605
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
606
607
608
609
610
611
                gemm2_bias_shuffled.append(
                    w2_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                    .contiguous()
                )
612
613

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
614
615
616
617
618
619
620
621
622
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
623
624

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
625
626
627
628
629
630
631
632
633
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
634
635

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
636
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
637
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
638
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
639
640
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
                requires_grad=False,
            )
            layer.w2_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
663
664
665
666

            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
726
727
                    w13_scale_swapped.view(torch.uint8)
                ).reshape(orig_shape)
728
729
730
731

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
732
733
734
735
736
737
738
739
740
741
742
                    w2_s.view(torch.uint8)
                ).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
                layer.w13_weight_scale = Parameter(
                    w13_scale_interleaved, requires_grad=False
                )
                layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
                layer.w2_weight_scale = Parameter(
                    w2_scale_interleaved, requires_grad=False
                )
743
744
745
746
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
747
748
749
                    w_interleaved = w.reshape(
                        w_shape[0], w_shape[1], (w_shape[2] // 4), 4
                    )
750
751
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
752
753
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                    )
754
755
                    return w_interleaved

756
757
                w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
758
759
760

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
761
762
763
764
765
766
767
768
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)

                layer.w13_weight = torch.nn.Parameter(
                    torch.cat([w3_w, w1_w], dim=1), requires_grad=False
                )
                layer.w13_bias = torch.nn.Parameter(
                    w13_bias_swapped, requires_grad=False
                )
769
                layer.w13_weight_scale = torch.nn.Parameter(
770
771
                    w31_scales_interleaved, requires_grad=False
                )
772
                layer.w2_weight_scale = torch.nn.Parameter(
773
774
                    w2_scales_interleaved, requires_grad=False
                )
775
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
776
777
778
779
780
781
782
783
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)

784
785
786
787
788
            # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
            # (stored in self.fused_experts) to determine if the MoE has a
            # batched activation format. As self.fused_experts is not
            # initialized at this point, we resort to checking the MoE config
            # directly.
789
            is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
790
            if is_batched_moe:
791
792
793
794
795
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8

            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
796
797
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
798
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
799
800
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )
801
802

            self.w13_precision_config = PrecisionConfig(
803
804
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
805
            self.w2_precision_config = PrecisionConfig(
806
807
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )
808

809
810
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
811
812
813
814
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
815
        else:
816
817
818
819
            raise ValueError(
                f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
                f"should be one of: {list(Mxfp4Backend)}."
            )
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quantization config for the selected backend.

        Args:
            layer: The MoE layer whose (already processed) biases and weight
                scales are referenced by the returned config.

        Returns:
            The ``FusedMoEQuantConfig`` matching ``self.mxfp4_backend``.
        """
        backend = self.mxfp4_backend

        if backend == Mxfp4Backend.MARLIN:
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )

        if backend == Mxfp4Backend.TRITON:
            # The Triton path swizzles the scales in
            # process_weights_after_loading and stores them inside
            # PrecisionConfig objects; those are passed instead of raw scales.
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=self.w13_precision_config,
                w2_scale=self.w2_precision_config,
            )

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        ):
            # These backends quantize activations to MXFP8 at runtime.
            return mxfp4_mxfp8_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )

        if backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )

        # Every remaining backend (e.g. SM90_FI_MXFP4_BF16) uses the generic
        # OCP MX config.
        return ocp_mx_moe_quant_config(
            quant_dtype="mxfp4",
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the experts kernel that pairs with ``prepare_finalize``.

        Args:
            prepare_finalize: The prepare/finalize object. Its
                ``activation_format`` selects batched vs. standard experts.
            layer: The MoE layer; its SwiGLU parameters (``gemm1_alpha``,
                ``gemm1_beta``, ``gemm1_clamp_limit``) feed the TRTLLM path.

        Returns:
            The permute/experts/unpermute implementation for
            ``self.mxfp4_backend``.

        Raises:
            NotImplementedError: If the backend has no implementation for the
                requested activation format.
        """
        if (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            # Among MXFP4 backends, only Marlin provides a batched experts kernel.
            if self.mxfp4_backend != Mxfp4Backend.MARLIN:
                raise NotImplementedError(
                    f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                    "EP batched experts format"
                )
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            assert self.moe_quant_config is not None
            return BatchedMarlinExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                moe_config=self.moe,
            )

        assert self.moe_quant_config is not None

        if self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # B200 code-path
            kwargs = {
                "gemm1_alpha": layer.gemm1_alpha,
                "gemm1_beta": layer.gemm1_beta,
                "gemm1_clamp_limit": layer.gemm1_clamp_limit,
                # TODO(bnell): part of quant_config
                "max_capture_size": self.max_capture_size,
            }
            return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)

        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            return MarlinExperts(self.moe, self.moe_quant_config)

        if self.mxfp4_backend == Mxfp4Backend.TRITON:
            # LoRA-enabled layers get the unfused Triton experts variant.
            if self.moe.is_lora_enabled:
                return UnfusedOAITritonExperts(self.moe, self.moe_quant_config)
            return OAITritonExperts(self.moe, self.moe_quant_config)

        raise NotImplementedError(
            f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
        )
    @property
    def is_monolithic(self) -> bool:
        """Whether this backend handles routing and experts in one call.

        Monolithic backends are driven through ``apply_monolithic``, which
        receives raw router logits. The other backends go through ``apply``
        with precomputed ``topk_weights``/``topk_ids``.
        """
        monolithic_backends = (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
            Mxfp4Backend.TRITON,
        )
        return self.mxfp4_backend in monolithic_backends

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the non-monolithic MXFP4 experts with precomputed routing.

        Handles the Marlin backend through the modular kernel (``self.moe_mk``).
        Handles the FlashInfer CUTLASS backends (SM100 MXFP4xMXFP8 and SM90
        MXFP4xBF16) through ``flashinfer_cutlass_fused_moe``.

        Args:
            layer: The MoE layer holding the processed weights, scales, biases
                and SwiGLU parameters.
            x: Input hidden states.
            topk_weights: Per-token weights of the selected experts (passed
                to the CUTLASS kernel as ``token_final_scales``).
            topk_ids: Per-token selected expert ids.
            shared_experts_input: Not used by this method.

        Returns:
            The MoE output. On the CUTLASS paths this is a bf16 tensor shaped
            like ``x``.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer``.
        """
        assert not self.is_monolithic
        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            assert self.moe_mk is not None

            return self.moe_mk(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        # Only the FlashInfer CUTLASS backends are expected past this point.
        assert (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        )
        from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

        # Backend-specific preparation
        if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
            from flashinfer import mxfp8_quantize

            # Quantize the activations to MXFP8 for the MXFP4xMXFP8 kernel.
            x_quant, x_scale = mxfp8_quantize(x, True, 32)

            # Per-expert input scales of 1.0. They fill the input-scale slots
            # of quant_scales (presumably unused with MXFP8 activation
            # scaling; confirm).
            fake_input_scale = torch.ones(self.num_experts, device=x.device)
            quant_scales = [
                layer.w13_weight_scale.contiguous().view(torch.int32),
                fake_input_scale,
                layer.w2_weight_scale.contiguous().view(torch.int32),
                fake_input_scale,
            ]

            fi_input = x_quant
            extra_kwargs = dict(
                use_mxfp8_act_scaling=True,
                input_sf=x_scale,
                fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
                fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
            )
        elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
            # Weight-only path: bf16 activations are fed to the kernel as is.
            assert x.dtype == torch.bfloat16

            quant_scales = [
                layer.w13_weight_scale,
                layer.w2_weight_scale,
            ]

            fi_input = x
            extra_kwargs = dict(
                use_w4_group_scaling=True,
                fc1_expert_weights=layer.w13_weight,
                fc2_expert_weights=layer.w2_weight,
            )

        # The kernel writes into this preallocated buffer.
        output = torch.empty_like(x, dtype=torch.bfloat16)

        flashinfer_cutlass_fused_moe(
            input=fi_input,
            token_selected_experts=topk_ids.to(torch.int).contiguous(),
            token_final_scales=topk_weights,
            output_dtype=torch.bfloat16,
            output=output,
            quant_scales=quant_scales,
            fc1_expert_biases=layer.w13_bias,
            fc2_expert_biases=layer.w2_bias,
            swiglu_alpha=layer.gemm1_alpha,
            swiglu_beta=layer.gemm1_beta,
            swiglu_limit=layer.gemm1_clamp_limit,
            tp_size=self.moe.tp_size,
            tp_rank=self.moe.tp_rank,
            ep_size=self.moe.ep_size,
            ep_rank=self.moe.ep_rank,
            tune_max_num_tokens=max(self.max_capture_size, 1),
            **extra_kwargs,
        )

        return output

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the MXFP4 experts in a single backend call.

        Used by the FlashInfer TRTLLM-gen backends (SM100 MXFP4xBF16 and
        MXFP4xMXFP8) and by the Triton backend. Each receives the raw router
        logits and performs routing itself.

        Args:
            layer: The MoE layer holding processed weights, scales, biases,
                SwiGLU parameters and routing settings.
            x: Input hidden states.
            router_logits: Router logits for ``x``.

        Returns:
            The MoE output tensor.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer``.
            ValueError: If ``self.mxfp4_backend`` is not a monolithic backend.
        """
        assert self.is_monolithic

        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer import trtllm_fp4_block_scale_moe

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                # Weight-only path: bf16 activations, no activation scales.
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                # Reinterpret the scale bytes and give them x's leading dims.
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                routing_logits=router_logits.to(torch.bfloat16),
                routing_bias=None,
                hidden_states=x_quant,
                hidden_states_scale=x_scale,
                gemm1_weights=layer.w13_weight,  # uint8 (e2m1 x 2)
                gemm1_weights_scale=layer.w13_weight_scale,  # ue8m0, stored as uint8
                gemm1_bias=layer.w13_bias,  # fp32 per expert per channel
                gemm1_alpha=layer.gemm1_alpha,  # fp32 per expert
                gemm1_beta=layer.gemm1_beta,  # fp32 per expert
                gemm1_clamp_limit=layer.gemm1_clamp_limit,  # fp32 per expert
                gemm2_weights=layer.w2_weight,  # uint8 (e2m1 x 2)
                gemm2_weights_scale=layer.w2_weight_scale,  # ue8m0
                gemm2_bias=layer.w2_bias,  # fp32 per expert per channel
                output1_scale_scalar=None,
                output1_scale_gate_scalar=None,
                output2_scale_scalar=None,
                num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                n_group=None,
                topk_group=None,
                intermediate_size=self.intermediate_size,  # padded to multiple of 256
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=self.num_experts,
                routed_scaling_factor=None,
                # 1 selects the renormalizing routing method, 0 the default one.
                routing_method_type=1 if layer.renormalize else 0,
                do_finalize=True,
                tune_max_num_tokens=max(self.max_capture_size, 1),
            )[0]
            return trtllm_gen_output

        if self.mxfp4_backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=layer.top_k,
                renormalize=layer.renormalize,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")


class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 MoE method for Intel XPU, backed by ``vllm_xpu_kernels``.

    Loaded weights are used as is: there is no post-load repacking. Routing
    and the expert computation run in ``apply_monolithic``.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the parent's MXFP4 parameters and record ``hidden_size``."""
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        # Keep the hidden size exactly as passed by the caller.
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Intentionally a no-op. The parent's backend-specific repacking is
        # skipped, and the loaded tensors go to xpu_fused_moe unchanged.
        pass

    @property
    def is_monolithic(self) -> bool:
        # Routing and experts are always handled together in apply_monolithic.
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Select experts from ``router_logits`` and run the XPU MXFP4 kernel.

        Args:
            layer: The MoE layer holding weights, scales, biases and routing
                settings.
            x: Input hidden states. Must be 2-D (tokens x hidden); this is
                enforced by the ``x.size()`` unpacking below.
            router_logits: Router logits for ``x``.

        Returns:
            The output of ``xpu_fused_moe``.
        """
        assert layer.activation == MoEActivation.SWIGLUOAI, (
            "Only swiglu_oai activation is supported for "
            f"XPU MXFP4 MoE, not {layer.activation}."
        )
        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe

        num_tokens, _ = x.size()

        if layer.use_grouped_topk:
            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
                x,
                router_logits,
                layer.top_k,
                layer.renormalize,
                n_expert_group=layer.num_expert_group,
                n_topk_group=layer.topk_group,
                scoring_func=layer.scoring_func,
                routed_scaling_factor=layer.routed_scaling_factor,
                bias=layer.e_score_correction_bias,
            )
        else:
            # topk_softmax fills preallocated output buffers in place. Only
            # this branch needs them, so they are allocated here rather than
            # for every call.
            routing_weights = torch.empty(
                num_tokens, layer.top_k, dtype=torch.float32, device=x.device
            )
            selected_experts = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            token_expert_indices = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            torch.ops._moe_C.topk_softmax(
                routing_weights,
                selected_experts,
                token_expert_indices,
                router_logits,
                layer.renormalize,
                layer.e_score_correction_bias,
            )

        return xpu_fused_moe(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_bias=layer.w13_bias if self.moe.has_bias else None,
            w13_scales=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            w2_scales=layer.w2_weight_scale,
            topk_weights=routing_weights,
            topk_ids=selected_experts,
            n_experts_per_token=layer.top_k,
            activation=layer.activation,
            num_experts=layer.local_num_experts,
            is_mxfp4=True,
        )