mxfp4.py 50.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from enum import Enum
4
5
6
7
8

import torch
from torch.nn.parameter import Parameter

from vllm import envs
9
from vllm._aiter_ops import rocm_aiter_ops
10
from vllm.config import get_current_vllm_config
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.attention import Attention
13
14
15
16
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
17
    MoEActivation,
18
)
19
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
20
21
22
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
23
from vllm.model_executor.layers.fused_moe.config import (
24
    FusedMoEQuantConfig,
25
    mxfp4_mxfp8_moe_quant_config,
26
    mxfp4_w4a16_moe_quant_config,
27
    ocp_mx_moe_quant_config,
28
)
29
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
30
    BatchedMarlinExperts,
31
32
    MarlinExperts,
)
33
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
34
    OAITritonExperts,
35
    UnfusedOAITritonExperts,
36
)
37
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
38
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
39
40
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
41
42
43
    QuantizationConfig,
    QuantizeMethodBase,
)
44
45
46
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
47
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
48
49
    prepare_moe_fp4_layer_for_marlin,
)
50
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
51
    CK_MXFP4_MOE_DIM_ALIGNMENT,
52
53
    _can_support_mxfp4,
    _swizzle_mxfp4,
54
    get_padding_alignment,
55
56
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
57
from vllm.model_executor.utils import set_weight_attrs
58
from vllm.platforms import current_platform
59
from vllm.utils.flashinfer import has_flashinfer
60
from vllm.utils.import_utils import has_triton_kernels
Cyrus Leung's avatar
Cyrus Leung committed
61
from vllm.utils.math_utils import round_up
62

63
64
65
logger = init_logger(__name__)


66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
    """Kernel backends that can run an MXFP4-quantized MoE layer.

    Every member has an explicit integer value. ``NONE`` (0) is what the
    backend selectors return when no usable backend was found.
    """

    NONE = 0

    # FlashInfer kernels, for SM100-family and SM90 devices
    SM100_FI_MXFP4_MXFP8_TRTLLM = 1
    SM100_FI_MXFP4_MXFP8_CUTLASS = 2
    SM100_FI_MXFP4_BF16 = 3
    SM90_FI_MXFP4_BF16 = 4

    # Marlin kernels
    MARLIN = 5

    # triton_kernels-based kernels
    TRITON = 6

    # CK MoE backend used through AITER on ROCm
    CK = 7

84

85
86
87
88
89
90
91
92
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    """Return an MXFP4 MoE backend that is known to support LoRA.

    Only CUDA is handled; every other platform gets ``Mxfp4Backend.NONE``.
    FlashInfer backends are never chosen here. The choice is between
    Triton and Marlin:

    - Triton, when ``envs.VLLM_MXFP4_USE_MARLIN`` is exactly ``False`` and
      triton_kernels can run on this device;
    - Marlin otherwise.
    """
    if not current_platform.is_cuda():
        return Mxfp4Backend.NONE

    # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
    # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
    # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
    can_run_triton = has_triton_kernels() and (
        (9, 0) <= current_platform.get_device_capability() < (11, 0)
    )

    # Deliberate identity check: only an explicit False opts into Triton.
    # Any other value of the flag keeps Marlin.
    keep_marlin = envs.VLLM_MXFP4_USE_MARLIN is not False
    if keep_marlin or not can_run_triton:
        logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
        return Mxfp4Backend.MARLIN

    logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
    return Mxfp4Backend.TRITON
107
108
109


def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    """Select the MXFP4 MoE backend for the current platform.

    Args:
        with_lora_support: when True, delegate to
            ``get_mxfp4_backend_with_lora`` so only LoRA-capable backends
            are considered.

    Returns:
        On CUDA, a FlashInfer backend is returned when the device
        capability, FlashInfer availability and the matching
        ``VLLM_USE_FLASHINFER_MOE_*`` flag allow it. Otherwise Triton or
        Marlin is returned. XPU gets ``MARLIN``. ROCm gets ``CK`` (AITER
        enabled on gfx950) or else ``TRITON``. ``NONE`` is returned when
        nothing applies.
    """
    if with_lora_support:
        return get_mxfp4_backend_with_lora()

    if current_platform.is_cuda():
        on_sm90 = current_platform.is_device_capability(90)
        on_sm100_family = current_platform.is_device_capability_family(100)

        if (
            on_sm90
            and has_flashinfer()
            and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        ):
            logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
            return Mxfp4Backend.SM90_FI_MXFP4_BF16

        if on_sm100_family and has_flashinfer():
            # Opt-in flags are checked in priority order; BF16 is the default.
            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS:
                logger.info_once(
                    "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100"
                )
                return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8:
                logger.info_once(
                    "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100",
                    scope="local",
                )
                return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            logger.info_once(
                "Using FlashInfer MXFP4 BF16 backend for SM100, "
                "For faster performance on SM100, consider setting "
                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
                "accuracy."
            )
            return Mxfp4Backend.SM100_FI_MXFP4_BF16

        if (on_sm100_family or on_sm90) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results."
            )

        # No FlashInfer backend was chosen: pick Triton if allowed, else Marlin.
        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
        triton_usable = has_triton_kernels() and (
            (9, 0) <= current_platform.get_device_capability() < (11, 0)
        )
        if not envs.VLLM_MXFP4_USE_MARLIN and triton_usable:
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON
        logger.info_once("Using Marlin backend")
        return Mxfp4Backend.MARLIN

    if current_platform.is_xpu():
        logger.info_once("Using xpu backend on XPU")
        return Mxfp4Backend.MARLIN

    if current_platform.is_rocm():
        from vllm.platforms.rocm import on_gfx950

        if rocm_aiter_ops.is_enabled() and on_gfx950():
            logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
            return Mxfp4Backend.CK
        if has_triton_kernels():
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON

    return Mxfp4Backend.NONE
185
186
187


class Mxfp4Config(QuantizationConfig):
    """Quantization config for checkpoints with MXFP4-quantized MoE experts.

    Only FusedMoE layers get an MXFP4 method. Linear layers always fall
    back to ``UnquantizedLinearMethod``. Attention layers (and any other
    layer type) are left without a quantize method.
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Layer-name patterns matched by is_layer_skipped for linear layers.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # The contents of `config` are not read: the instance is created
        # with no ignored layers.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum device capability (80, i.e. SM80).
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method for ``layer``, or None to leave it alone.

        - ``LinearBase``: always ``UnquantizedLinearMethod``. A debug message
          is logged unless the layer matches ``ignored_layers``.
        - ``FusedMoE``: ``XpuMxfp4MoEMethod`` on XPU, else ``Mxfp4MoEMethod``.
        - ``Attention``: None, after a debug message.
        - anything else: None.
        """
        if isinstance(layer, LinearBase):
            explicitly_ignored = bool(self.ignored_layers) and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if not explicitly_ignored:
                # TODO: Add support for MXFP4 Linear Method.
                # MXFP4 LinearMethod is available in AMD-Quark, refer to that
                # implementation if you are interested in enabling MXFP4 here.
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            moe_method_cls = (
                XpuMxfp4MoEMethod if current_platform.is_xpu() else Mxfp4MoEMethod
            )
            return moe_method_cls(layer.moe_config)

        if isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )

        return None

    def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
        """Report MXFP4 for every layer; this config has no other format."""
        return True

250
251

class Mxfp4MoEMethod(FusedMoEMethodBase):
252
253
    """MXFP4 MoE quantization method."""

254
    def __init__(self, moe: FusedMoEConfig):
        """Choose and validate the MXFP4 MoE backend for ``moe``.

        Raises:
            ValueError: the CK backend was selected, but
                ``moe.intermediate_size_per_partition`` is not a multiple of
                ``CK_MXFP4_MOE_DIM_ALIGNMENT`` and Triton kernels are not
                available as a fallback.
            AssertionError: no compatible backend was found.
        """
        super().__init__(moe)
        self.weight_dtype = "mxfp4"
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)

        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
        # alignment requirements. Fall back to Triton when not met.
        if (
            self.mxfp4_backend == Mxfp4Backend.CK
            and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
        ):
            if has_triton_kernels():
                logger.warning_once(
                    "CK MXFP4 MoE GEMM does not support "
                    "intermediate_size_per_partition=%d (not a multiple of "
                    "%d). Falling back to Triton backend.",
                    moe.intermediate_size_per_partition,
                    CK_MXFP4_MOE_DIM_ALIGNMENT,
                )
                self.mxfp4_backend = Mxfp4Backend.TRITON
            else:
                raise ValueError(
                    f"CK MXFP4 MoE GEMM does not support "
                    f"intermediate_size_per_partition="
                    f"{moe.intermediate_size_per_partition} (not a multiple "
                    f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton "
                    f"fallback is available. Use a compatible "
                    f"tensor_parallel_size."
                )

        # Adjacent string literals are concatenated verbatim, so each piece
        # carries its own separating space.
        assert self.mxfp4_backend != Mxfp4Backend.NONE, (
            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton/CK). "
            "Please check your environment and try again."
        )
        # Cache passed to flashinfer's get_w2_permute_indices_with_cache when
        # shuffling weights in process_weights_after_loading.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        # Set in process_weights_after_loading for some backends (e.g. Marlin);
        # stays None otherwise.
        self.moe_kernel: mk.FusedMoEKernel | None = None

297
298
299
300
301
302
303
304
305
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register zero-filled MXFP4 expert parameters on ``layer``.

        The hidden and intermediate sizes are first padded to the alignment
        the selected backend needs. Six parameters are then registered, in
        this order:

        - ``w13_weight``, ``w13_weight_scale``, ``w13_bias``: the fused
          gate_up_proj (column parallel).
        - ``w2_weight``, ``w2_weight_scale``, ``w2_bias``: the down_proj
          (row parallel).

        Packed weights are uint8 with two values per byte (hence ``// 2``).
        Scales are uint8 with one entry per 32-element block. Biases are
        bfloat16. ``extra_weight_attrs`` is attached to every parameter.
        The padded sizes are recorded on ``self``.
        """
        self.num_experts = num_experts
        # Packed weights and block scales are both stored as raw bytes.
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        # FIXME (zyongye): ship after torch and safetensors support mxfp4
        # is_torch_mxfp4_available = (
        #     hasattr(torch, "float4_e2m1fn_x2") and
        #     hasattr(torch, "float8_e8m0fnu"))
        # if is_torch_mxfp4_available:
        #     weight_dtype = torch.float4_e2m1fn_x2
        #     scale_dtype = torch.float8_e8m0fnu

        mxfp4_block = 32  # number of elements sharing one scale

        backend = self.mxfp4_backend
        padded_inter = intermediate_size_per_partition
        if backend == Mxfp4Backend.MARLIN:
            # The moe marlin kernel requires n % 256 == 0 and k % 128 == 0
            # for each linear:
            #   gate_up_proj: n = 2 * padded_inter, k = hidden_size
            #   down_proj:    n = hidden_size,      k = padded_inter
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_alignment = 128 if current_platform.is_xpu() else 256
            hidden_size = round_up(hidden_size, hidden_alignment)

            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = padded_inter
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # Pad both dims to multiples of 256, which is itself a multiple
            # of 2 * mxfp4_block. This lets non-uniformly sharded tensors fit
            # and be swizzled; the extra padding also helps performance.
            padded_inter = round_up(intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
        ):
            padded_inter = round_up(intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 128)
        elif current_platform.is_rocm():
            alignment = get_padding_alignment()
            padded_inter = round_up(intermediate_size_per_partition, alignment)
            hidden_size = round_up(hidden_size, alignment)
        else:
            # Remaining backends only pad the intermediate dimension.
            padded_inter = round_up(intermediate_size_per_partition, 64)

        self.intermediate_size = padded_inter
        self.hidden_size = hidden_size
        self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
        self.intermediate_pad = padded_inter - intermediate_size_per_partition

        gate_up_rows = 2 * padded_inter
        param_specs = (
            # Fused gate_up_proj (column parallel)
            (
                "w13_weight",
                (num_experts, gate_up_rows, hidden_size // 2),
                weight_dtype,
            ),
            (
                "w13_weight_scale",
                (num_experts, gate_up_rows, hidden_size // mxfp4_block),
                scale_dtype,
            ),
            ("w13_bias", (num_experts, gate_up_rows), torch.bfloat16),
            # down_proj (row parallel)
            (
                "w2_weight",
                (num_experts, hidden_size, padded_inter // 2),
                weight_dtype,
            ),
            (
                "w2_weight_scale",
                (num_experts, hidden_size, padded_inter // mxfp4_block),
                scale_dtype,
            ),
            ("w2_bias", (num_experts, hidden_size), torch.bfloat16),
        )
        for param_name, shape, dtype in param_specs:
            param = torch.nn.Parameter(
                torch.zeros(*shape, dtype=dtype), requires_grad=False
            )
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
453
        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
            prepare_moe_fp4_layer_for_marlin(
                layer, input_dtype=get_marlin_input_dtype()
            )

            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
            assert self.moe_quant_config is not None

            prepare_finalize = maybe_make_prepare_finalize(
                moe=self.moe,
                quant_config=self.moe_quant_config,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                allow_new_interface=True,
            )
            assert prepare_finalize is not None

469
            self.moe_kernel = mk.FusedMoEKernel(
470
471
472
473
474
475
476
477
                prepare_finalize,
                MarlinExperts(
                    self.moe,
                    self.moe_quant_config,
                ),
                inplace=not self.moe.disable_inplace,
                shared_experts=None,
            )
478
479
480
481
482
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
        ):
            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
483
            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
484
485
486
487
488
489
490
491
492
493
494
495
496

            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
497
498
            sf_block_size = 32  # mxfp4 block size

499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
533
534
535
536
537
538
539
540

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

co63oc's avatar
co63oc committed
541
            # Swap w1 and w3 as the definition of
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
574
                # w13 weight shuffling
575
                permute_indices = get_w2_permute_indices_with_cache(
576
577
578
579
                    self._cache_permute_indices,
                    w13_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
580
581
582
583
584
                gemm1_weights_mxfp4_shuffled.append(
                    w13_weight[i]
                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                    .contiguous()
                )
585
                # w13 scale shuffling
586
                permute_sf_indices = get_w2_permute_indices_with_cache(
587
588
589
590
591
                    self._cache_permute_indices,
                    w13_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
592
                gemm1_scales_mxfp4_shuffled.append(
593
594
595
596
597
598
599
600
                    nvfp4_block_scale_interleave(
                        w13_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w13_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
601
                # w13 bias shuffling
602
                permute_bias_indices = get_w2_permute_indices_with_cache(
603
604
605
606
                    self._cache_permute_indices,
                    w13_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
607
608
609
610
611
612
                gemm1_bias_shuffled.append(
                    w13_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                    .contiguous()
                )
613
                # w2 weight shuffling
614
                permute_indices = get_w2_permute_indices_with_cache(
615
616
617
618
                    self._cache_permute_indices,
                    w2_weight[i].view(torch.uint8),
                    epilogue_tile_m,
                )
619
620
621
622
623
                gemm2_weights_mxfp4_shuffled.append(
                    w2_weight[i]
                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                    .contiguous()
                )
624
                # w2 scale shuffling
625
                permute_sf_indices = get_w2_permute_indices_with_cache(
626
627
628
629
630
                    self._cache_permute_indices,
                    w2_weight_scale[i].view(torch.uint8),
                    epilogue_tile_m,
                    num_elts_per_sf=16,
                )
631
                gemm2_scales_mxfp4_shuffled.append(
632
633
634
635
636
637
638
639
                    nvfp4_block_scale_interleave(
                        w2_weight_scale[i]
                        .view(torch.uint8)[
                            permute_sf_indices.to(w2_weight_scale.device)
                        ]
                        .contiguous()
                    )
                )
640
                # w2 bias shuffling
641
                permute_indices = get_w2_permute_indices_with_cache(
642
643
644
645
                    self._cache_permute_indices,
                    w2_bias[i].clone().reshape(-1, 1),
                    epilogue_tile_m,
                )
646
647
648
649
650
651
                gemm2_bias_shuffled.append(
                    w2_bias[i]
                    .clone()
                    .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                    .contiguous()
                )
652
653

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
654
655
656
657
658
659
660
661
662
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
663
664

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
665
666
667
668
669
670
671
672
673
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )
674
675

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
676
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
677
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
678
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
679
680
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
681
682
683
684
685
686
687
688
689
690
                requires_grad=False,
            )
            layer.w2_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
691
692
693
            sf_block_size = 32  # mxfp4 block size

            # Common shape assertions
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size
            )
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752

            # De-interleave and swap for w13 weight, bias, and scales
            w13_w = layer.w13_weight.data
            gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
            deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
            w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

            w13_b = layer.w13_bias.data.to(torch.float32)
            gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
            deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
            b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
            w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

            w13_s = layer.w13_weight_scale.data
            gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
            deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
            s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
            w13_scale_swapped = torch.cat([s3, s1], dim=1)

            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import block_scale_interleave

                orig_shape = w13_scale_swapped.shape
                w13_scale_interleaved = block_scale_interleave(
753
754
                    w13_scale_swapped.view(torch.uint8)
                ).reshape(orig_shape)
755
756
757
758

                w2_s = layer.w2_weight_scale.data
                orig_shape = w2_s.shape
                w2_scale_interleaved = block_scale_interleave(
759
760
761
762
763
764
765
766
767
768
769
                    w2_s.view(torch.uint8)
                ).reshape(orig_shape)

                layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
                layer.w13_weight_scale = Parameter(
                    w13_scale_interleaved, requires_grad=False
                )
                layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
                layer.w2_weight_scale = Parameter(
                    w2_scale_interleaved, requires_grad=False
                )
770
771
772
773
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:

                def _interleave_mxfp4_cutlass_sm90(w):
                    w_shape = w.shape
774
775
776
                    w_interleaved = w.reshape(
                        w_shape[0], w_shape[1], (w_shape[2] // 4), 4
                    )
777
778
                    w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                    w_interleaved = w_interleaved.reshape(
779
780
                        w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                    )
781
782
                    return w_interleaved

783
784
                w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
                w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
785
786
787

                w2_weight_scale = layer.w2_weight_scale.data
                w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
788
789
790
791
792
793
794
795
                w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)

                layer.w13_weight = torch.nn.Parameter(
                    torch.cat([w3_w, w1_w], dim=1), requires_grad=False
                )
                layer.w13_bias = torch.nn.Parameter(
                    w13_bias_swapped, requires_grad=False
                )
796
                layer.w13_weight_scale = torch.nn.Parameter(
797
798
                    w31_scales_interleaved, requires_grad=False
                )
799
                layer.w2_weight_scale = torch.nn.Parameter(
800
801
                    w2_scales_interleaved, requires_grad=False
                )
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817

            # theses two kernels go through the `flashinfer_cutlass_fused_moe` path
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )

            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
            assert self.moe_quant_config is not None
            prepare_finalize = maybe_make_prepare_finalize(
                moe=self.moe,
                quant_config=self.moe_quant_config,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                allow_new_interface=True,
            )
            assert prepare_finalize is not None

818
            self.moe_kernel = mk.FusedMoEKernel(
819
820
821
822
823
824
825
                prepare_finalize,
                FlashInferExperts(
                    moe_config=self.moe,
                    quant_config=self.moe_quant_config,
                ),
                shared_experts=None,
            )
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
        elif self.mxfp4_backend == Mxfp4Backend.CK:
            if layer.w13_bias is not None:
                layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
            if layer.w2_bias.data is not None:
                layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)

            e, n, k = layer.w13_weight.shape
            layer.w13_weight.view(torch.uint8).copy_(
                layer.w13_weight.data.view(torch.uint8)
                .view(e, n // 2, 2, k)
                .permute(0, 2, 1, 3)
                .contiguous()
                .view(e, n, k)
            )
            layer.w13_weight_scale.data = (
                layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
                .permute(0, 2, 1, 3)
                .contiguous()
                .view(e, n, -1)
            )
            layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
            layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)

            layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
                layer.w13_weight, 16, True
            )
            shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
                layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
                self.num_experts,
                True,
            )

            layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
                layer.w2_weight, 16, False
            )
            shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
                layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
                self.num_experts,
                False,
            )

            layer.w13_bias.data = (
                layer.w13_bias.data.view(-1, n // 2, 2)
                .permute(0, 2, 1)
                .contiguous()
                .view(-1, n)
            )

            layer.w13_weight_scale = torch.nn.Parameter(
                shuffled_w13_scale, requires_grad=False
            )
            layer.w2_weight_scale = torch.nn.Parameter(
                shuffled_w2_scale, requires_grad=False
            )
            # replace_parameter(layer, "w13_bias", w13_bias)
            # replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
            # replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
            # replace_parameter(layer, "w13_weight", w13_weight)
            # replace_parameter(layer, "w2_weight", w2_weight)

886
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
887
888
889
890
891
892
893
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)
894
895
896
897
898
            # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
            # (stored in self.fused_experts) to determine if the MoE has a
            # batched activation format. As self.fused_experts is not
            # initialized at this point, we resort to checking the MoE config
            # directly.
899
            is_batched_moe = self.moe.use_deepep_ll_kernels
900
            if is_batched_moe:
901
902
903
904
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8
            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
905
906
                layer.w13_weight, layer.w13_weight_scale, num_warps
            )
907
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
908
909
                layer.w2_weight, layer.w2_weight_scale, num_warps
            )
910
911

            self.w13_precision_config = PrecisionConfig(
912
913
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
            )
914
            self.w2_precision_config = PrecisionConfig(
915
916
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
            )
917
918
            self.w13_weight = w13_weight
            self.w2_weight = w2_weight
919
920
921
922
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = w13_weight
            layer.w2_weight = w2_weight
923

924
        else:
925
926
927
928
            raise ValueError(
                f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
                f"should be one of: {list(Mxfp4Backend)}."
            )
929

930
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the fused-MoE quantization config for the active backend.

        * TRITON: w4a16 config whose "scales" are the ``PrecisionConfig``
          objects stored on ``self`` (``w13_precision_config`` /
          ``w2_precision_config``).
        * MARLIN, SM100_FI_MXFP4_BF16, SM90_FI_MXFP4_BF16, CK: w4a16 config
          built from the layer's weight scales.
        * SM100_FI_MXFP4_MXFP8_TRTLLM / SM100_FI_MXFP4_MXFP8_CUTLASS: mxfp4
          weights with mxfp8 activations, built from the layer's weight scales.
        * Any other backend: a generic OCP-MX config with
          ``quant_dtype="mxfp4"``.

        Biases always come from ``layer.w13_bias`` / ``layer.w2_bias``.
        """
        backend = self.mxfp4_backend

        if backend == Mxfp4Backend.TRITON:
            # Triton keeps its swizzled scales inside PrecisionConfig objects.
            return mxfp4_w4a16_moe_quant_config(
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=self.w13_precision_config,
                w2_scale=self.w2_precision_config,
            )

        w4a16_backends = (
            Mxfp4Backend.MARLIN,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
            Mxfp4Backend.CK,
        )
        mxfp8_act_backends = (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
        )

        if backend in w4a16_backends:
            build_config = mxfp4_w4a16_moe_quant_config
        elif backend in mxfp8_act_backends:
            build_config = mxfp4_mxfp8_moe_quant_config
        else:
            return ocp_mx_moe_quant_config(
                quant_dtype="mxfp4",
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )

        return build_config(
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )
980

981
982
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Pick the modular experts implementation for the active backend.

        Args:
            prepare_finalize: Prepare/finalize object. Its
                ``activation_format`` selects between the batched-experts
                implementation and the standard ones.
            layer: Not used by this method; kept for the base-class interface.

        Returns:
            An experts implementation built from ``self.moe`` and
            ``self.moe_quant_config``.

        Raises:
            NotImplementedError: If the backend has no implementation for the
                requested activation format.
        """
        backend = self.mxfp4_backend
        is_batched = (
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        )

        if is_batched:
            # In this method only Marlin offers a batched experts format.
            if backend != Mxfp4Backend.MARLIN:
                raise NotImplementedError(
                    f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                    "EP batched experts format"
                )
            tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert tokens_per_rank is not None
            assert self.moe_quant_config is not None
            return BatchedMarlinExperts(
                max_num_tokens=tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                moe_config=self.moe,
            )

        quant_config = self.moe_quant_config
        assert quant_config is not None

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            # B200 code-path
            # TODO(bnell): max_capture_size should become part of quant_config
            return TrtLlmGenExperts(
                self.moe, quant_config, max_capture_size=self.max_capture_size
            )
        if backend == Mxfp4Backend.MARLIN:
            return MarlinExperts(self.moe, quant_config)
        if backend == Mxfp4Backend.TRITON:
            experts_cls = (
                UnfusedOAITritonExperts
                if self.moe.is_lora_enabled
                else OAITritonExperts
            )
            return experts_cls(self.moe, quant_config)

        raise NotImplementedError(
            f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
        )
1027

1028
1029
    @property
    def is_monolithic(self) -> bool:
        """Whether this method runs through ``apply_monolithic``.

        LoRA always forces the modular (``apply``) path. Otherwise the
        FlashInfer TRT-LLM SM100 backends, Triton and CK are monolithic.
        """
        if self.moe.is_lora_enabled:
            return False
        return self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
            Mxfp4Backend.TRITON,
            Mxfp4Backend.CK,
        )

1039
1040
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular MoE kernel using precomputed routing results.

        Args:
            layer: The FusedMoE layer holding the processed weights.
            x: Input hidden states.
            topk_weights: Routing weights forwarded to the kernel.
            topk_ids: Selected expert ids forwarded to the kernel.
            shared_experts_input: Forwarded unchanged to the kernel.

        Returns:
            Whatever ``self.moe_kernel.apply`` returns.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer``.
        """
        assert not self.is_monolithic
        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        # The capability check stays inside ``assert`` (skipped under -O).
        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        # The modular path is only expected for these backends.
        assert self.mxfp4_backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
            Mxfp4Backend.SM90_FI_MXFP4_BF16,
            Mxfp4Backend.MARLIN,
        )

        kernel = self.moe_kernel
        assert kernel is not None
        return kernel.apply(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            expert_map=layer.expert_map,
            shared_experts_input=shared_experts_input,
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Compute the MoE layer from raw router logits in one backend call.

        Routing (top-k selection) and the expert GEMMs are both delegated to
        the backend, so this receives ``router_logits`` instead of
        precomputed top-k weights/ids.

        Args:
            layer: The FusedMoE layer holding the processed weights.
            x: Input hidden states.
            router_logits: Raw gating logits used for expert selection.

        Returns:
            The backend's output.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer``.
            ValueError: If ``self.mxfp4_backend`` has no monolithic path here.
        """
        assert self.is_monolithic

        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        # The capability check stays inside ``assert`` (skipped under -O).
        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        backend = self.mxfp4_backend

        if backend in (
            Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
            Mxfp4Backend.SM100_FI_MXFP4_BF16,
        ):
            from flashinfer import trtllm_fp4_block_scale_moe

            if backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
                # bf16 activations are passed through without a scale tensor.
                assert x.dtype == torch.bfloat16
                hidden, hidden_scale = x, None
            else:
                # SM100_FI_MXFP4_MXFP8_TRTLLM: quantize activations to mxfp8
                # and reshape the scales to x's leading dimensions.
                from flashinfer import mxfp8_quantize

                hidden, raw_scale = mxfp8_quantize(x, False)
                hidden_scale = raw_scale.view(torch.float8_e4m3fn).reshape(
                    *x.shape[:-1], -1
                )

            results = trtllm_fp4_block_scale_moe(
                routing_logits=router_logits.to(torch.bfloat16),
                routing_bias=None,
                hidden_states=hidden,
                hidden_states_scale=hidden_scale,
                gemm1_weights=layer.w13_weight,  # uint8 (e2m1 x 2)
                gemm1_weights_scale=layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                gemm1_bias=layer.w13_bias,  # fp32 per expert per channel
                gemm1_alpha=layer.gemm1_alpha,  # fp32 per expert
                gemm1_beta=layer.gemm1_beta,  # fp32 per expert
                gemm1_clamp_limit=layer.gemm1_clamp_limit,  # fp32 per expert
                gemm2_weights=layer.w2_weight,  # uint8 (e2m1 x 2)
                gemm2_weights_scale=layer.w2_weight_scale,  # ue8m0
                gemm2_bias=layer.w2_bias,  # fp32 per expert per channel
                output1_scale_scalar=None,
                output1_scale_gate_scalar=None,
                output2_scale_scalar=None,
                num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                n_group=None,
                topk_group=None,
                intermediate_size=self.intermediate_size,  # padded to multiple of 256
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=self.num_experts,
                routed_scaling_factor=None,
                routing_method_type=1 if layer.renormalize else 0,
                do_finalize=True,
                tune_max_num_tokens=max(self.max_capture_size, 1),
            )
            # The first element of the kernel's result is the layer output.
            return results[0]

        if backend == Mxfp4Backend.CK:
            # NOTE(review): the fourth argument is a hard-coded True, whereas
            # the TRT-LLM and Triton paths follow layer.renormalize. If this
            # argument is the renormalize flag, confirm that is intended.
            ck_weights, ck_ids = rocm_aiter_ops.fused_topk(
                x, router_logits, layer.top_k, True
            )
            return rocm_aiter_ops.fused_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                ck_weights,
                ck_ids,
                activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
                quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                doweight_stage1=False,
                hidden_pad=self.hidden_pad // 128 * 128,
                intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
                bias1=layer.w13_bias,
                bias2=layer.w2_bias,
            )

        if backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                gating_output=router_logits,
                topk=layer.top_k,
                renormalize=layer.renormalize,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
1198
1199


1200
class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
    def __init__(self, moe_config: FusedMoEConfig):
        """Set up the XPU MXFP4 MoE method.

        Args:
            moe_config: Fused-MoE configuration. It is forwarded to the base
                class and also stored as ``self.moe_config``.
        """
        super().__init__(moe_config)
        # Keep a direct handle on the config object passed in.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the MoE parameters via the base implementation.

        After the base class has registered the weights, the ``hidden_size``
        argument is recorded as ``self.original_hidden_size``.
        """
        base_args = (
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
        )
        super().create_weights(*base_args, **extra_weight_attrs)
        # The caller-supplied hidden size. Presumably this is distinct from any
        # padded size the base class tracks (TODO confirm).
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Do nothing after loading.

        This overrides the base-class post-processing with a no-op. No weight
        transformation happens in this class before ``apply_monolithic``
        reads the layer's parameters.
        """
1226

1227
1228
1229
1230
1231
    @property
    def is_monolithic(self) -> bool:
        """Always ``True``: XPU always runs through ``apply_monolithic``."""
        monolithic = True
        return monolithic

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Route tokens and run the XPU fused MXFP4 MoE kernel.

        Top-k routing is computed either with ``fused_grouped_topk`` (when
        ``layer.use_grouped_topk``) or with ``topk_softmax``. The resulting
        weights and expert ids are then passed to ``xpu_fused_moe``.

        Args:
            layer: The FusedMoE layer holding the weights and routing options.
            x: 2-D input hidden states (unpacked as ``(num_tokens, _)``).
            router_logits: Raw gating logits used for expert selection.

        Returns:
            The output of ``xpu_fused_moe``.

        Raises:
            AssertionError: If ``layer.activation`` is not SWIGLUOAI.
        """
        assert layer.activation == MoEActivation.SWIGLUOAI, (
            "Only swiglu_oai activation is supported for "
            f"XPU MXFP4 MoE, not {layer.activation}."
        )
        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe

        num_tokens, _ = x.size()

        if layer.use_grouped_topk:
            # fused_grouped_topk returns its results. Pre-allocated buffers
            # would just be discarded on this path, so none are created.
            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
                x,
                router_logits,
                layer.top_k,
                layer.renormalize,
                n_expert_group=layer.num_expert_group,
                n_topk_group=layer.topk_group,
                scoring_func=layer.scoring_func,
                routed_scaling_factor=layer.routed_scaling_factor,
                bias=layer.e_score_correction_bias,
            )
        else:
            # topk_softmax is called for its effect on the buffers passed in
            # (its return value is unused), so allocate them only here.
            routing_weights = torch.empty(
                num_tokens, layer.top_k, dtype=torch.float32, device=x.device
            )
            selected_experts = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            token_expert_indices = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            torch.ops._moe_C.topk_softmax(
                routing_weights,
                selected_experts,
                token_expert_indices,
                router_logits,
                layer.renormalize,
                layer.e_score_correction_bias,
            )

        return xpu_fused_moe(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_bias=layer.w13_bias if self.moe.has_bias else None,
            w13_scales=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            w2_scales=layer.w2_weight_scale,
            topk_weights=routing_weights,
            topk_ids=selected_experts,
            n_experts_per_token=layer.top_k,
            activation=layer.activation.value,
            num_experts=layer.local_num_experts,
            is_mxfp4=True,
        )