mxfp4.py 49.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from enum import Enum
4
5
6
7
8

import torch
from torch.nn.parameter import Parameter

from vllm import envs
9
from vllm._aiter_ops import rocm_aiter_ops
10
from vllm.config import get_current_vllm_config
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.attention import Attention
13
14
15
16
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
17
    MoEActivation,
18
)
19
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
20
21
22
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
23
from vllm.model_executor.layers.fused_moe.config import (
24
    FusedMoEQuantConfig,
25
    mxfp4_mxfp8_moe_quant_config,
26
    mxfp4_w4a16_moe_quant_config,
27
    ocp_mx_moe_quant_config,
28
)
29
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
30
    BatchedMarlinExperts,
31
32
    MarlinExperts,
)
33
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
34
    OAITritonExperts,
35
    UnfusedOAITritonExperts,
36
)
37
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
38
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
39
40
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
41
42
43
    QuantizationConfig,
    QuantizeMethodBase,
)
44
45
46
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
47
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
48
49
    prepare_moe_fp4_layer_for_marlin,
)
50
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
51
52
    _can_support_mxfp4,
    _swizzle_mxfp4,
53
    get_padding_alignment,
54
55
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
56
from vllm.model_executor.utils import set_weight_attrs
57
from vllm.platforms import current_platform
58
from vllm.utils.flashinfer import has_flashinfer
59
from vllm.utils.import_utils import has_triton_kernels
Cyrus Leung's avatar
Cyrus Leung committed
60
from vllm.utils.math_utils import round_up
61

62
63
64
logger = init_logger(__name__)


65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
    """Kernel backends that can execute MXFP4-quantized MoE layers.

    ``NONE`` means backend selection found nothing usable for the current
    platform/configuration.
    """

    NONE = 0

    # FlashInfer backends; member names encode the target GPU architecture
    # (SM90 / SM100), the MXFP4 weight + activation formats and, where
    # relevant, the kernel family (TRTLLM / CUTLASS).
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4

    # Marlin backend
    MARLIN = 5

    # Triton backend (triton_kernels)
    TRITON = 6

    # CK backend through AITER on ROCm
    CK = 7

83

84
85
86
87
88
89
90
91
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    """
    Pick an MXFP4 MoE backend that is known to work together with LoRA.

    Only CUDA is handled; any other platform yields ``Mxfp4Backend.NONE``.
    FlashInfer backends are never considered here. Triton is selected only
    when ``envs.VLLM_MXFP4_USE_MARLIN`` is exactly ``False`` and
    triton_kernels can run on this GPU; in every other case Marlin is used.
    """
    if not current_platform.is_cuda():
        return Mxfp4Backend.NONE

    # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
    # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
    can_use_triton = has_triton_kernels() and (
        (9, 0) <= current_platform.get_device_capability() < (11, 0)
    )

    # The identity check against False (rather than a falsiness test) means a
    # value other than an explicit False keeps Marlin as the LoRA default.
    # NOTE(review): presumably this is meant to treat an unset env var as
    # "use Marlin" -- confirm how envs reports an unset value.
    prefer_triton = envs.VLLM_MXFP4_USE_MARLIN is False and can_use_triton
    if not prefer_triton:
        logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
        return Mxfp4Backend.MARLIN

    logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
    return Mxfp4Backend.TRITON
106
107
108


def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    """
    Select the MXFP4 MoE backend for the current platform.

    Args:
        with_lora_support: If True, defer to ``get_mxfp4_backend_with_lora``
            so only LoRA-capable backends are considered.

    Returns:
        The chosen backend. ``Mxfp4Backend.NONE`` is returned for platforms
        that are neither CUDA, XPU nor ROCm, and for ROCm when neither the
        AITER CK path (gfx950) nor triton_kernels is available.
    """
    if with_lora_support:
        return get_mxfp4_backend_with_lora()

    if current_platform.is_cuda():
        on_sm90 = current_platform.is_device_capability(90)
        on_sm100 = current_platform.is_device_capability_family(100)
        # FlashInfer availability is only queried on Hopper/Blackwell.
        flashinfer_ok = (on_sm90 or on_sm100) and has_flashinfer()

        if on_sm90 and flashinfer_ok and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16

        if on_sm100 and flashinfer_ok:
            # Env-var opt-ins in priority order (CUTLASS, then TRTLLM);
            # otherwise FlashInfer BF16 is the SM100 default.
            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS:
                logger.info_once(
                    "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100"
                )
                return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8:
                logger.info_once(
                    "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100", scope="local"
                )
                return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy."
            )
            return Mxfp4Backend.SM100_FI_MXFP4_BF16

        if (on_sm90 or on_sm100) and not flashinfer_ok:
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results."
            )

        # No FlashInfer backend was chosen: fall back to Marlin or Triton.
        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
        can_use_triton = has_triton_kernels() and (
            (9, 0) <= current_platform.get_device_capability() < (11, 0)
        )
        if envs.VLLM_MXFP4_USE_MARLIN or not can_use_triton:
            logger.info_once("Using Marlin backend")
            return Mxfp4Backend.MARLIN
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON

    if current_platform.is_xpu():
        logger.info_once("Using xpu backend on XPU")
        return Mxfp4Backend.MARLIN

    if current_platform.is_rocm():
        from vllm.platforms.rocm import on_gfx950

        if rocm_aiter_ops.is_enabled() and on_gfx950():
            logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
            return Mxfp4Backend.CK
        if has_triton_kernels():
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON

    return Mxfp4Backend.NONE
184
185
186


class Mxfp4Config(QuantizationConfig):
    """Quantization config for models whose MoE experts are stored as MXFP4.

    Only ``FusedMoE`` layers receive an MXFP4 quantize method. Linear layers
    always fall back to ``UnquantizedLinearMethod``, and attention layers get
    no quantize method at all.
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Layer prefixes to skip; only consulted for LinearBase layers.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # NOTE(review): ``config`` is not inspected, so instances built here
        # always have ignored_layers=None.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method to use for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Fully-qualified name of ``layer`` in the model.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers, an MXFP4 MoE method
            for ``FusedMoE`` layers (the XPU variant on XPU), and ``None`` for
            everything else (including attention).
        """
        if isinstance(layer, LinearBase):
            explicitly_ignored = bool(self.ignored_layers) and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if not explicitly_ignored:
                # TODO: Add support for MXFP4 Linear Method.
                # MXFP4 LinearMethod is available in AMD-Quark, refer to that
                # implementation if you are interested in enabling MXFP4 here.
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            if current_platform.is_xpu():
                return XpuMxfp4MoEMethod(layer.moe_config)
            return Mxfp4MoEMethod(layer.moe_config)

        if isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )
        return None

    def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
        """Report MXFP4 for every layer; this config has no other scheme."""
        return True

249
250

class Mxfp4MoEMethod(FusedMoEMethodBase):
251
252
    """MXFP4 MoE quantization method."""

253
    def __init__(self, moe: FusedMoEConfig):
        """Select the MXFP4 backend for ``moe`` and initialise per-layer state.

        Args:
            moe: MoE layer configuration; ``moe.is_lora_enabled`` restricts
                backend selection to LoRA-capable backends.

        Raises:
            AssertionError: if no compatible MXFP4 MoE backend was found.
        """
        super().__init__(moe)
        self.weight_dtype = "mxfp4"
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)

        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # Adjacent string literals are concatenated verbatim, so each piece
        # must carry its own separating space.
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
            "Please check your environment and try again."
        )
        # Permutation-index cache handed to FlashInfer's
        # get_w2_permute_indices_with_cache while shuffling weights.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        # Built in process_weights_after_loading for backends that run through
        # a modular FusedMoEKernel (e.g. Marlin; also CUTLASS/SM90).
        self.moe_kernel: mk.FusedMoEKernel | None = None
270

271
272
273
274
275
276
277
278
279
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register zero-initialised MXFP4 expert weights, scales and biases.

        The intermediate size (and, except on the generic fallback path, the
        hidden size) is first rounded up to the alignment required by the
        selected backend. The padded sizes are stored on ``self``; for Marlin
        they are also recorded on ``layer``.

        Registered parameters, in this order (E = num_experts, I = padded
        intermediate size, H = (possibly padded) hidden size, B = 32):

        - ``w13_weight``:       (E, 2 * I, H // 2) uint8 (packed)
        - ``w13_weight_scale``: (E, 2 * I, H // B) uint8
        - ``w13_bias``:         (E, 2 * I)         bfloat16
        - ``w2_weight``:        (E, H, I // 2)     uint8 (packed)
        - ``w2_weight_scale``:  (E, H, I // B)     uint8
        - ``w2_bias``:          (E, H)             bfloat16

        ``extra_weight_attrs`` is attached to every parameter, and its
        ``hidden_pad`` entry (default 0) is stored as ``self.hidden_pad``.
        """
        self.num_experts = num_experts
        # FIXME (zyongye): ship after torch and safetensors support mxfp4
        # (torch.float4_e2m1fn_x2 weights / torch.float8_e8m0fnu scales);
        # until then both are stored as raw uint8.
        packed_dtype = torch.uint8
        scale_dtype = torch.uint8
        mxfp4_block = 32

        backend = self.mxfp4_backend
        padded_inter = intermediate_size_per_partition
        if backend == Mxfp4Backend.MARLIN:
            # The Marlin MoE kernel needs n % 256 == 0 and k % 128 == 0:
            #   gate_up_proj: n = 2 * padded_inter, k = hidden_size
            #   down_proj:    n = hidden_size,      k = padded_inter
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(
                hidden_size, 128 if current_platform.is_xpu() else 256
            )
            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = padded_inter
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # At least a multiple of 2 * mxfp4_block so non-uniformly sharded
            # tensors fit and can be swizzled; 256 adds padding for speed.
            padded_inter = round_up(intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 128)
        elif current_platform.is_rocm():
            pad_align = get_padding_alignment()
            padded_inter = round_up(intermediate_size_per_partition, pad_align)
            hidden_size = round_up(hidden_size, pad_align)
        else:
            padded_inter = round_up(intermediate_size_per_partition, 64)

        self.intermediate_size = padded_inter
        self.hidden_size = hidden_size
        self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
        self.intermediate_pad = padded_inter - intermediate_size_per_partition

        def _add_param(name: str, *shape: int, dtype: torch.dtype) -> None:
            # Zero-filled, frozen parameter registered on `layer` and tagged
            # with the caller-supplied weight attributes.
            param = torch.nn.Parameter(
                torch.zeros(*shape, dtype=dtype), requires_grad=False
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        gate_up_rows = 2 * padded_inter

        # Fused gate_up_proj (column parallel)
        _add_param(
            "w13_weight",
            num_experts,
            gate_up_rows,
            hidden_size // 2,
            dtype=packed_dtype,
        )
        _add_param(
            "w13_weight_scale",
            num_experts,
            gate_up_rows,
            hidden_size // mxfp4_block,
            dtype=scale_dtype,
        )
        _add_param("w13_bias", num_experts, gate_up_rows, dtype=torch.bfloat16)

        # down_proj (row parallel)
        _add_param(
            "w2_weight",
            num_experts,
            hidden_size,
            padded_inter // 2,
            dtype=packed_dtype,
        )
        _add_param(
            "w2_weight_scale",
            num_experts,
            hidden_size,
            padded_inter // mxfp4_block,
            dtype=scale_dtype,
        )
        _add_param("w2_bias", num_experts, hidden_size, dtype=torch.bfloat16)

    def process_weights_after_loading(self, layer):
427
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
            prepare_moe_fp4_layer_for_marlin(
                layer, input_dtype=get_marlin_input_dtype()
            )

            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
            assert self.moe_quant_config is not None

            prepare_finalize = maybe_make_prepare_finalize(
                moe=self.moe,
                quant_config=self.moe_quant_config,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                allow_new_interface=True,
            )
            assert prepare_finalize is not None

443
            self.moe_kernel = mk.FusedMoEKernel(
444
445
446
447
448
449
450
451
                prepare_finalize,
                MarlinExperts(
                    self.moe,
                    self.moe_quant_config,
                ),
                inplace=not self.moe.disable_inplace,
                shared_experts=None,
            )
452
453
454
455
456
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
457
            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
458
459
460
461
462
463
464
465
466
467
468
469
470

            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
471
472
            sf_block_size = 32  # mxfp4 block size

473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
507
508
509
510
511
512
513
514

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

co63oc's avatar
co63oc committed
515
            # Swap w1 and w3 as the definition of
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
548
                # w13 weight shuffling
549
                permute_indices = get_w2_permute_indices_with_cache(
550
551
552
553
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
554
555
556
557
558
                gemm1_weights_mxfp4_shuffled.append(
                    w13_weight[i]
                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                    .contiguous()
                )
559
                # w13 scale shuffling
560
                permute_sf_indices = get_w2_permute_indices_with_cache(
561
562
563
564
565
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
566
                gemm1_scales_mxfp4_shuffled.append(
567
568
569
570
571
572
573
574
                    nvfp4_block_scale_interleave(
                        w13_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w13_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
575
                # w13 bias shuffling
576
                permute_bias_indices = get_w2_permute_indices_with_cache(
577
578
579
580
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
581
582
583
584
585
586
                gemm1_bias_shuffled.append(
                    w13_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                    .contiguous()
                )
587
                # w2 weight shuffling
588
                permute_indices = get_w2_permute_indices_with_cache(
589
590
591
592
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
593
594
595
596
597
                gemm2_weights_mxfp4_shuffled.append(
                    w2_weight[i]
                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                    .contiguous()
                )
598
                # w2 scale shuffling
599
                permute_sf_indices = get_w2_permute_indices_with_cache(
600
601
602
603
604
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
605
                gemm2_scales_mxfp4_shuffled.append(
606
607
608
609
610
611
612
613
                    nvfp4_block_scale_interleave(
                        w2_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w2_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
614
                # w2 bias shuffling
615
                permute_indices = get_w2_permute_indices_with_cache(
616
617
618
619
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
620
621
622
623
624
625
                gemm2_bias_shuffled.append(
                    w2_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                    .contiguous()
                )
626
627

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
628
629
630
631
632
633
634
635
636
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
637
638

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
639
640
641
642
643
644
645
646
647
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
648
649

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
650
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
651
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
652
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
653
654
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
655
656
657
658
659
660
661
662
663
664
                requires_grad=False,
            )
            layer.w2_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
665
666
667
            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
727
728
                    w13_scale_swapped.view(torch.uint8)
                ).reshape(orig_shape)
729
730
731
732

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
733
734
735
736
737
738
739
740
741
742
743
                    w2_s.view(torch.uint8)
                ).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
                layer.w13_weight_scale = Parameter(
                    w13_scale_interleaved, requires_grad=False
                )
                layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
                layer.w2_weight_scale = Parameter(
                    w2_scale_interleaved, requires_grad=False
                )
744
745
746
747
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
748
749
750
                    w_interleaved = w.reshape(
                        w_shape[0], w_shape[1], (w_shape[2] // 4), 4
                    )
751
752
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
753
754
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                    )
755
756
                    return w_interleaved

757
758
                w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
759
760
761

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
762
763
764
765
766
767
768
769
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)

                layer.w13_weight = torch.nn.Parameter(
                    torch.cat([w3_w, w1_w], dim=1), requires_grad=False
                )
                layer.w13_bias = torch.nn.Parameter(
                    w13_bias_swapped, requires_grad=False
                )
770
                layer.w13_weight_scale = torch.nn.Parameter(
771
772
                    w31_scales_interleaved, requires_grad=False
                )
773
                layer.w2_weight_scale = torch.nn.Parameter(
774
775
                    w2_scales_interleaved, requires_grad=False
                )
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791

            # theses two kernels go through the `flashinfer_cutlass_fused_moe` path
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )

            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
            assert self.moe_quant_config is not None
            prepare_finalize = maybe_make_prepare_finalize(
                moe=self.moe,
                quant_config=self.moe_quant_config,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                allow_new_interface=True,
            )
            assert prepare_finalize is not None

792
            self.moe_kernel = mk.FusedMoEKernel(
793
794
795
796
797
798
799
                prepare_finalize,
                FlashInferExperts(
                    moe_config=self.moe,
                    quant_config=self.moe_quant_config,
                ),
                shared_experts=None,
            )
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
        elif self.mxfp4_backend == Mxfp4Backend.CK:
            if layer.w13_bias is not None:
                layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
            if layer.w2_bias.data is not None:
                layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)

            e, n, k = layer.w13_weight.shape
            layer.w13_weight.view(torch.uint8).copy_(
                layer.w13_weight.data.view(torch.uint8)
                .view(e, n // 2, 2, k)
                .permute(0, 2, 1, 3)
                .contiguous()
                .view(e, n, k)
            )
            layer.w13_weight_scale.data = (
                layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
                .permute(0, 2, 1, 3)
                .contiguous()
                .view(e, n, -1)
            )
            layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
            layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)

            layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
                layer.w13_weight, 16, True
            )
            shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
                layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
                self.num_experts,
                True,
            )

            layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
                layer.w2_weight, 16, False
            )
            shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
                layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
                self.num_experts,
                False,
            )

            layer.w13_bias.data = (
                layer.w13_bias.data.view(-1, n // 2, 2)
                .permute(0, 2, 1)
                .contiguous()
                .view(-1, n)
            )

            layer.w13_weight_scale = torch.nn.Parameter(
                shuffled_w13_scale, requires_grad=False
            )
            layer.w2_weight_scale = torch.nn.Parameter(
                shuffled_w2_scale, requires_grad=False
            )
            # replace_parameter(layer, "w13_bias", w13_bias)
            # replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
            # replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
            # replace_parameter(layer, "w13_weight", w13_weight)
            # replace_parameter(layer, "w2_weight", w2_weight)

860
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
861
862
863
864
865
866
867
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)
868
869
870
871
872
            # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
            # (stored in self.fused_experts) to determine if the MoE has a
            # batched activation format. As self.fused_experts is not
            # initialized at this point, we resort to checking the MoE config
            # directly.
873
            is_batched_moe = self.moe.use_deepep_ll_kernels
874
            if is_batched_moe:
875
876
877
878
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8
            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
879
880
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
881
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
882
883
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )
884
885

            self.w13_precision_config = PrecisionConfig(
886
887
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
888
            self.w2_precision_config = PrecisionConfig(
889
890
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )
891
892
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
893
894
895
896
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
897

898
        else:
899
900
901
902
            raise ValueError(
                f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
                f"should be one of: {list(Mxfp4Backend)}."
            )
903

904
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quantization config for the active backend.

        The Triton backend reads its weight scales from the
        ``PrecisionConfig`` objects built during weight processing. Every
        other backend reads the scales (and biases) directly from ``layer``:

        * FlashInfer MXFP8 backends (TRT-LLM, CUTLASS) -> MXFP4 x MXFP8 config
        * Marlin, FlashInfer BF16 backends, CK         -> W4A16 config
        * anything else                                -> generic OCP-MX
          config with ``quant_dtype="mxfp4"``

        Args:
            layer: The MoE layer holding the processed weights, scales and
                biases.

        Returns:
            The quantization config matching ``self.mxfp4_backend``.
        """
        backend = self.mxfp4_backend

        if backend == Mxfp4Backend.TRITON:
            # Triton keeps its swizzled scales in PrecisionConfig objects
            # rather than on the layer.
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=self.w13_precision_config,
                w2_scale=self.w2_precision_config,
            )

        layer_tensors = {
            "w1_bias": layer.w13_bias,
            "w2_bias": layer.w2_bias,
            "w1_scale": layer.w13_weight_scale,
            "w2_scale": layer.w2_weight_scale,
        }

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        ):
            return mxfp4_mxfp8_moe_quant_config(**layer_tensors)

        if backend in (
            Mxfp4Backend.MARLIN,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
            Mxfp4Backend.CK,
        ):
            return mxfp4_w4a16_moe_quant_config(**layer_tensors)

        return ocp_mx_moe_quant_config(quant_dtype="mxfp4", **layer_tensors)
954

955
956
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Pick the experts implementation for a modular MoE kernel.

        Batched (EP) activation format is only supported by Marlin. For the
        standard format:

        * FlashInfer TRT-LLM backends (MXFP8 or BF16) -> ``TrtLlmGenExperts``
        * Marlin                                     -> ``MarlinExperts``
        * Triton -> ``UnfusedOAITritonExperts`` when LoRA is enabled,
          ``OAITritonExperts`` otherwise

        Args:
            prepare_finalize: The prepare/finalize stage the experts will be
                paired with; its activation format decides batched vs. not.
            layer: The MoE layer (unused here).

        Returns:
            The experts implementation for ``self.mxfp4_backend``.

        Raises:
            NotImplementedError: If the backend has no experts implementation
                for the requested activation format.
        """
        backend = self.mxfp4_backend
        is_batched = (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        )

        if is_batched:
            if backend != Mxfp4Backend.MARLIN:
                raise NotImplementedError(
                    f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                    "EP batched experts format"
                )
            tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert tokens_per_rank is not None
            assert self.moe_quant_config is not None
            return BatchedMarlinExperts(
                max_num_tokens=tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                moe_config=self.moe,
            )

        quant_config = self.moe_quant_config
        assert quant_config is not None

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # B200 code-path.
            # TODO(bnell): max_capture_size should be part of quant_config.
            return TrtLlmGenExperts(
                self.moe, quant_config, max_capture_size=self.max_capture_size
            )

        if backend == Mxfp4Backend.MARLIN:
            return MarlinExperts(self.moe, quant_config)

        if backend == Mxfp4Backend.TRITON:
            experts_cls = (
                UnfusedOAITritonExperts
                if self.moe.is_lora_enabled
                else OAITritonExperts
            )
            return experts_cls(self.moe, quant_config)

        raise NotImplementedError(
            f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
        )
1001

1002
1003
    @property
    def is_monolithic(self) -> bool:
        """Whether the backend routes and runs the experts in one call.

        Monolithic backends are driven through ``apply_monolithic``, which
        takes raw router logits. All others go through ``apply`` with
        pre-computed top-k ids and weights. Enabling LoRA always forces the
        non-monolithic path.
        """
        monolithic_backends = (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
            Mxfp4Backend.TRITON,
            Mxfp4Backend.CK,
        )
        lora_enabled = self.moe.is_lora_enabled
        return not lora_enabled and self.mxfp4_backend in monolithic_backends

1013
1014
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE through the pre-built modular kernel.

        This is the non-monolithic path. Routing has already been computed by
        the caller. Only the FlashInfer CUTLASS (SM100 MXFP8 / SM90 BF16) and
        Marlin backends are accepted here.

        Args:
            layer: The MoE layer supplying weights and routing settings.
            x: Hidden states.
            topk_weights: Per-token weights of the selected experts.
            topk_ids: Per-token ids of the selected experts.
            shared_experts_input: Optional input forwarded to the kernel for
                shared experts.

        Returns:
            The result of ``self.moe_kernel.apply``.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer``.
        """
        assert not self.is_monolithic
        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        modular_backends = (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
            Mxfp4Backend.MARLIN,
        )
        assert self.mxfp4_backend in modular_backends

        kernel = self.moe_kernel
        assert kernel is not None
        return kernel.apply(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            expert_map=layer.expert_map,
            shared_experts_input=shared_experts_input,
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the experts in a single backend call.

        Serves the monolithic backends: FlashInfer TRT-LLM (MXFP8 or BF16
        activations), AITER CK, and Triton. Each consumes ``router_logits``
        directly and performs expert selection itself.

        Args:
            layer: The MoE layer holding processed weights, scales, biases and
                routing settings.
            x: Hidden states.
            router_logits: Router logits for ``x``.

        Returns:
            The MoE output produced by the selected backend.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer``.
            ValueError: If ``self.mxfp4_backend`` has no monolithic path.
        """
        assert self.is_monolithic

        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        backend = self.mxfp4_backend

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            from flashinfer import trtllm_fp4_block_scale_moe

            if backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                # BF16 activations are handed to the kernel unquantized.
                assert x.dtype == torch.bfloat16
                hidden_states, hidden_states_scale = x, None
            else:
                # SM100_FI_MXFP4_MXFP8_TRTLLM: quantize activations to MXFP8.
                from flashinfer import mxfp8_quantize

                hidden_states, raw_scale = mxfp8_quantize(x, False)
                hidden_states_scale = raw_scale.view(torch.float8_e4m3fn).reshape(
                    *x.shape[:-1], -1
                )

            outputs = trtllm_fp4_block_scale_moe(
                routing_logits=router_logits.to(torch.bfloat16),
                routing_bias=None,
                hidden_states=hidden_states,
                hidden_states_scale=hidden_states_scale,
                gemm1_weights=layer.w13_weight,  # uint8 (e2m1 x 2)
                gemm1_weights_scale=layer.w13_weight_scale,  # MXFP4 block scales
                gemm1_bias=layer.w13_bias,  # fp32 per expert per channel
                gemm1_alpha=layer.gemm1_alpha,  # fp32 per expert
                gemm1_beta=layer.gemm1_beta,  # fp32 per expert
                gemm1_clamp_limit=layer.gemm1_clamp_limit,  # fp32 per expert
                gemm2_weights=layer.w2_weight,  # uint8 (e2m1 x 2)
                gemm2_weights_scale=layer.w2_weight_scale,  # MXFP4 block scales
                gemm2_bias=layer.w2_bias,  # fp32 per expert per channel
                output1_scale_scalar=None,
                output1_scale_gate_scalar=None,
                output2_scale_scalar=None,
                num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                n_group=None,
                topk_group=None,
                intermediate_size=self.intermediate_size,  # padded to multiple of 256
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=self.num_experts,
                routed_scaling_factor=None,
                routing_method_type=1 if layer.renormalize else 0,
                do_finalize=True,
                tune_max_num_tokens=max(self.max_capture_size, 1),
            )
            return outputs[0]

        if backend == Mxfp4Backend.CK:
            # NOTE(review): the fourth argument is a hard-coded ``True``. If it
            # is the renormalize flag, this path ignores ``layer.renormalize``
            # (which the TRT-LLM and Triton paths honour) — confirm intended.
            topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
                x, router_logits, layer.top_k, True
            )
            return rocm_aiter_ops.fused_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
                quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                doweight_stage1=False,
                # Padding is rounded down to multiples of 128 (hidden) and 64
                # (intermediate, then doubled — presumably for the fused
                # gate/up projection; confirm).
                hidden_pad=self.hidden_pad // 128 * 128,
                intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
                bias1=layer.w13_bias,
                bias2=layer.w2_bias,
            )

        if backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=layer.top_k,
                renormalize=layer.renormalize,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
1172
1173


1174
class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 fused-MoE method for Intel XPU.

    Parameters are created by the parent class. Loaded weights are used as-is
    (no post-processing), and execution always goes through
    ``apply_monolithic``. That method computes routing itself and calls
    ``xpu_fused_moe`` from ``vllm_xpu_kernels``.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        # Keep a direct handle on the config this method was built with.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the MoE parameters via the parent class.

        Also records ``hidden_size`` exactly as passed in. NOTE(review):
        presumably this is the size before any padding the parent applies;
        confirm against the parent's ``create_weights``.
        """
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Intentionally a no-op.

        Overrides the parent's post-processing so the loaded weights, scales
        and biases are left untouched. ``apply_monolithic`` passes them to
        ``xpu_fused_moe`` as loaded.
        """
        pass

    @property
    def is_monolithic(self) -> bool:
        """Always ``True``: routing and expert compute run in one call."""
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Route tokens and run the fused MXFP4 MoE kernel on XPU.

        Routing uses ``fused_grouped_topk`` when ``layer.use_grouped_topk`` is
        set and ``topk_softmax`` otherwise. The selected experts and weights
        are then passed to ``xpu_fused_moe`` with the layer's MXFP4 weights,
        scales and (if the MoE has biases) biases.

        Args:
            layer: The MoE layer supplying weights and routing settings.
            x: Hidden states. It is unpacked as ``(num_tokens, hidden)``, so
                it must be 2-D.
            router_logits: Router logits for ``x``.

        Returns:
            The output of ``xpu_fused_moe``.
        """
        assert layer.activation == MoEActivation.SWIGLUOAI, (
            "Only swiglu_oai activation is supported for "
            f"XPU MXFP4 MoE, not {layer.activation}."
        )

        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe

        num_tokens, _ = x.size()

        if layer.use_grouped_topk:
            # fused_grouped_topk returns its results, so no output buffers
            # are allocated on this path (they would be discarded unused).
            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
                x,
                router_logits,
                layer.top_k,
                layer.renormalize,
                n_expert_group=layer.num_expert_group,
                n_topk_group=layer.topk_group,
                scoring_func=layer.scoring_func,
                routed_scaling_factor=layer.routed_scaling_factor,
                bias=layer.e_score_correction_bias,
            )
        else:
            # topk_softmax fills caller-provided buffers; its return value is
            # not used.
            routing_weights = torch.empty(
                num_tokens, layer.top_k, dtype=torch.float32, device=x.device
            )
            selected_experts = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            token_expert_indices = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            torch.ops._moe_C.topk_softmax(
                routing_weights,
                selected_experts,
                token_expert_indices,
                router_logits,
                layer.renormalize,
                layer.e_score_correction_bias,
            )

        return xpu_fused_moe(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_bias=layer.w13_bias if self.moe.has_bias else None,
            w13_scales=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            w2_scales=layer.w2_weight_scale,
            topk_weights=routing_weights,
            topk_ids=selected_experts,
            n_experts_per_token=layer.top_k,
            activation=layer.activation.value,
            num_experts=layer.local_num_experts,
            is_mxfp4=True,
        )