"vscode:/vscode.git/clone" did not exist on "b024a42e93d1f078816302579b577fb24b5939a4"
modelopt.py 72.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
18
19
20
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
22
from vllm.model_executor.layers.fused_moe.layer import (
23
24
25
26
27
28
29
30
31
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
32
from vllm.model_executor.layers.quantization import QuantizationMethods
33
from vllm.model_executor.layers.quantization.base_config import (
34
35
36
    QuantizationConfig,
    QuantizeMethodBase,
)
37
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
38
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
39
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
40
    flashinfer_trtllm_fp4_moe,
41
    flashinfer_trtllm_fp4_routed_moe,
42
    prepare_static_weights_for_trtllm_fp4_moe,
43
44
45
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
46
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
47
48
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
49
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
50
51
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
52
    is_flashinfer_supporting_global_sf,
53
54
55
56
57
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
58
59
60
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
)
61
62
63
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
64
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
65
66
67
68
69
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
70
from vllm.model_executor.layers.quantization.utils.quant_utils import (
71
72
73
74
75
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
76
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
77
    Fp8LinearOp,
78
    cutlass_block_fp8_supported,
79
80
    requantize_with_max_scale,
)
81
82
83
84
85
86
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
87
from vllm.scalar_type import scalar_types
88
89
90
91
92
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
93
from vllm.utils.math_utils import round_up
94

95
96
97
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

98
99
logger = init_logger(__name__)

100
101
102
103
104
105
106
107
108
109
# ModelOpt `quant_algo` values accepted by this module. Incoming values are
# upper-cased before being checked against this list, so lowercase spellings
# emitted by ModelOpt (e.g. "fp8_pb_wo") also match.
QUANT_ALGOS = [
    "FP8",  # FP8 per-tensor weight scale + optional static activation scale.
    "FP8_PER_CHANNEL_PER_TOKEN",  # FP8 per-channel weight + per-token act scale.
    "FP8_PB_WO",  # FP8 per-block, weight-only.
    "NVFP4",  # FP4.
]
# KV-cache quantization algorithms for ModelOpt checkpoints (presumably the
# accepted `kv_cache_quant_algo` values; the consumer is not in this chunk).
KV_CACHE_QUANT_ALGOS = ["FP8"]
111
112


113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantization method for ModelOpt FP8 checkpoints.

    Exists so that kv-cache scaling factors stored in FP8 checkpoints can be
    loaded; all behavior comes from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # No ModelOpt-specific state: hand the config straight to the base.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Shared base for ModelOpt quantization configs.

    Holds the ``exclude_modules`` list and dispatches ``get_quant_method`` to
    the method classes chosen by the subclass (``LinearMethodCls``,
    ``FusedMoEMethodCls``, ``KVCacheMethodCls``). ``from_config`` parses both
    supported checkpoint config layouts. Subclasses must implement
    ``_from_config``.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Three checks are tried in order:
        1. exact matching via ``is_layer_skipped``, which also handles fused
           layers through ``packed_modules_mapping``;
        2. legacy substring matching, also against ``prefix`` with a leading
           ``language_model.`` removed;
        3. ``fnmatch`` wildcard matching, because ModelOpt ``exclude_modules``
           entries are wildcard patterns.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module != prefix and (
                exclude_module in prefix
                or (
                    prefix.startswith("language_model.")
                    and exclude_module in prefix.removeprefix("language_model.")
                )
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(fnmatch(prefix, pattern) for pattern in self.exclude_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer`` (at ``prefix``), or None.

        Attention layers always get the KV-cache method. Excluded layers and
        (legacy) vision-encoder layers stay unquantized: linear layers get
        ``UnquantizedLinearMethod``, and every other layer type gets ``None``
        so it falls back to its own default. Remaining linear / fused-MoE
        layers get the subclass-selected quantized method.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # TODO: The hard coded vision_tower / vision_model check is not needed for
        # quantized checkpoints generated by ModelOpt >= 0.39.0, which list those
        # modules in exclude_modules. It is kept for loading checkpoints produced
        # by older versions.
        if (
            self.is_layer_excluded(prefix)
            or "vision_tower" in prefix
            or "vision_model" in prefix
        ):
            # Only linear layers can use UnquantizedLinearMethod; other layer
            # types (e.g. FusedMoE, embeddings) must fall back to their own
            # unquantized default, which they select when given None.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
        else:
            return None

        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` from HF module names to vLLM names."""
        if len(self.exclude_modules) > 0:
            # This is a workaround for the weights remapping issue:
            # https://github.com/vllm-project/vllm/issues/28072
            # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
            #        module_path*
            # It gets applied if the whole tree of modules rooted at module_path
            # is not quantized. Here we replace such pattern by 2 patterns that are
            # collectively equivalent to the original pattern:
            #        module_path
            #        module_path.*
            new_exclude_modules = []
            for exclude in self.exclude_modules:
                if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
                    new_exclude_modules.append(exclude[:-1])
                    new_exclude_modules.append(exclude[:-1] + ".*")
                else:
                    new_exclude_modules.append(exclude)

            self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse a ModelOpt quantization config dict and build the config.

        Accepts the legacy ModelOpt layout
        ``{"quantization": {"quant_algo": ..., "exclude_modules": [...]}}`` and
        the compressed-tensors style layout
        ``{"quant_algo": ..., "quant_method": "modelopt", "ignore": [...]}``.

        Raises:
            ValueError: if ``quant_algo`` is missing or unsupported, or if
                ``kv_cache_quant_algo``, the exclude list or ``group_size``
                has an invalid type.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        # None means no KV cache quantization; otherwise it must be a string.
        if kv_cache_quant_method is not None:
            if not isinstance(kv_cache_quant_method, str):
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_method)}"
                )
            kv_cache_quant_method = kv_cache_quant_method.upper()

        # An explicit null exclude list (e.g. `"ignore": null`) means nothing
        # is excluded, same as a missing key.
        if exclude_modules is None:
            exclude_modules = []
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8.

    Covers the FP8 ``quant_algo`` family (``FP8``,
    ``FP8_PER_CHANNEL_PER_TOKEN``, ``FP8_PB_WO``) and selects the matching
    linear method class for the checkpoint's algorithm.
    """

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # quant_algo -> LinearMethod implementation. The class names are
        # resolved when __init__ runs, so they may be defined later in the file.
        linear_method_by_algo = {
            "FP8": ModelOptFp8LinearMethod,
            "FP8_PER_CHANNEL_PER_TOKEN": ModelOptFp8PcPtLinearMethod,
            "FP8_PB_WO": ModelOptFp8PbWoLinearMethod,
        }
        selected = linear_method_by_algo.get(self.quant_method)
        if selected is None:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )
        self.LinearMethodCls = selected

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Return "modelopt" when ``hf_quant_cfg`` is a ModelOpt FP8 config.

        Requires ``quant_method == "modelopt"`` (case-insensitive) and an FP8
        ``quant_algo``. The algo is read either from a nested
        ``"quantization"`` dict or from the top level (compressed-tensors
        style). Returns None otherwise.
        """
        if hf_quant_cfg is None:
            return None

        # Only configs that explicitly name the community-standard
        # quant_method "modelopt" are considered.
        if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            nested = hf_quant_cfg["quantization"]
            if not isinstance(nested, dict):
                return None
            algo = nested.get("quant_algo", "")
        else:
            # Compressed-tensors style: quant_algo lives at the top level.
            algo = hf_quant_cfg.get("quant_algo", "")

        return "modelopt" if "FP8" in str(algo).upper() else None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        # Any FP8-family algo name means the weights were serialized as FP8.
        return cls(
            quant_method,
            "FP8" in quant_method,
            kv_cache_quant_method,
            exclude_modules,
        )
418

419
420
421
422

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation (input) scale. Dynamic scales may be supported in the future.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm.
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static per-tensor activation scale (read from the checkpoint).
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, its per-shard scales.

        Scale parameters hold one entry per logical output shard and are
        pre-filled with the float32 minimum before loading.
        """
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        total_out = sum(output_partition_sizes)
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        storage_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_out, input_size_per_partition, dtype=storage_dtype
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # WEIGHT SCALE, then INPUT SCALE: one float32 slot per output shard.
        for param_name in ("weight_scale", "input_scale"):
            per_shard = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=loader,
            )
            per_shard[:] = torch.finfo(torch.float32).min
            layer.register_parameter(param_name, per_shard)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce per-shard scales to single per-tensor scales.

        When shards were stored with different weight scales, the weight is
        requantized through ``requantize_with_max_scale``; otherwise the
        shared scale is kept as-is. The weight is then stored transposed and
        the input scale is reduced to its maximum.
        """
        shard_scales = layer.weight_scale
        if (shard_scales == shard_scales[0]).all():
            fp8_weight = layer.weight
            tensor_scale = shard_scales.max()
        else:
            tensor_scale, fp8_weight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )

        layer.weight = Parameter(fp8_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(tensor_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Delegate to the per-tensor FP8 linear op with the static scales.
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
510
511


512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8_PER_CHANNEL_PER_TOKEN`` checkpoints.

    Each Linear in the checkpoint is expected to carry:
    - ``weight``: float8_e4m3fn, shape [out, in];
    - ``weight_scale``: float32, shape [out] (one scale per output channel);
    - no ``input_scale``: activations are quantized dynamically per token.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic activation quantization with one scale per token.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and per-output-channel ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        fp8_weight = torch.empty(
            out_features, input_size_per_partition, dtype=torch.float8_e4m3fn
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=fp8_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(out_features, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        # Pre-fill with the float32 minimum until the checkpoint is loaded.
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Store the weight transposed (presumably the layout Fp8LinearOp
        # expects) and re-wrap the scale as a plain frozen Parameter.
        transposed = layer.weight.t()
        scale_data = layer.weight_scale.data
        layer.weight = Parameter(transposed, requires_grad=False)
        layer.weight_scale = Parameter(scale_data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # No static input scale: the op quantizes activations per token.
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        # Weights use (block_n, block_k) tiles; activations are quantized
        # dynamically in (1, block_k) groups.
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and its block-wise ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized, if any output
                shard is not a multiple of the block height, or if the input
                width is not a multiple of the block width.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        block_n, block_k = self._WEIGHT_BLOCK_SIZE

        # Validate before allocating anything. Each logical output shard (e.g.
        # q/k/v of a merged layer) must be block-aligned on its own: the weight
        # loaders translate each shard's element offset into a block offset, so
        # a misaligned shard would load the wrong scale rows even when the
        # summed width happens to be divisible by the block size.
        for shard_size in output_partition_sizes:
            if shard_size % block_n != 0:
                raise ValueError(
                    "ModelOpt FP8_PB_WO requires out_features divisible by "
                    f"{block_n}, got {shard_size} "
                    f"(output partitions: {list(output_partition_sizes)})."
                )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        # Pre-fill with the float32 minimum until the checkpoint is loaded.
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Freeze the weight and flatten the scale to [out_blk, in_blk].

        Raises:
            ValueError: if ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Activations are quantized dynamically inside the op; no input scale.
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales
    and activation scales.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The FusedMoE layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            # Non-gated MoE cannot use the TensorRT-LLM (min-latency) path;
            # fall back to the CUTLASS (high-throughput) backend instead.
            if (
                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and not self.moe.is_act_and_mul
            ):
                logger.info_once(
                    "Non-gated MoE is not supported for min-latency mode, "
                    "falling back to high-throughput mode"
                )
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for this method, or None.

        Returns None for the FlashInfer TRT-LLM backend and for FlashInfer
        CUTLASS when data parallelism is not in use (``dp_size == 1``).
        Otherwise builds the FlashInfer CUTLASS object or defers to the base
        class.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # TP case: avoid convert to ModularKernelMethod - to be refactored.
            if self.moe.dp_size == 1:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=False,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Select the CUTLASS FP8 experts implementation for modular kernels."""
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights and, for FP8 checkpoints, their scales.

        w13_weight is ``[num_experts, w13_up_dim, hidden_size]``, where
        ``w13_up_dim`` is ``2 * intermediate`` for gated (act-and-mul) MoE and
        ``intermediate`` otherwise. w2_weight is
        ``[num_experts, hidden_size, intermediate]``.
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        When w13 has two weight scales per expert (w1 and w3), they are merged
        into their per-expert max and the w13 weights are requantized with
        that combined scale. Input scales are reduced to a single scalar (max
        over experts).
        """
        if self.flashinfer_moe_backend is not None:
            self._maybe_pad_intermediate_for_flashinfer(layer)

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            if self.moe.is_act_and_mul:
                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        register_moe_scaling_factors(layer)

    def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
        """Pad intermediate size so FlashInfer kernels' alignment constraints hold.

        Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
        used for GEMM to be divisible by a small alignment value. When this is
        not satisfied (e.g. with certain tensor-parallel sizes), we pad the
        gate/up and down projection weights along the intermediate dim with
        zeros.

        For gated MoE, w13 stacks ``[w1; w3]`` along dim 1. Each half is
        padded separately, so w3 starts at ``padded_intermediate``. The later
        requantization in ``process_weights_after_loading`` splits w13 into
        two halves of size ``w13.shape[1] // 2``, so each half must itself be
        padded.
        """
        if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
            return

        # Current local intermediate size (per partition) is the K dimension of
        # the down projection.
        num_experts, hidden_size, intermediate = layer.w2_weight.shape

        min_alignment = 16
        padded_intermediate = round_up(intermediate, min_alignment)

        if padded_intermediate == intermediate:
            return

        logger.info(
            "Padding intermediate size from %d to %d for up/down projection weights.",
            intermediate,
            padded_intermediate,
        )

        num_shards = 2 if self.moe.is_act_and_mul else 1

        # Pad w13 shard-by-shard (w1 and, if gated, w3) along its
        # intermediate dimension; padded rows stay zero.
        w13 = layer.w13_weight.data
        padded_w13 = w13.new_zeros(
            (num_experts, num_shards * padded_intermediate, hidden_size)
        )
        for shard in range(num_shards):
            src_start = shard * intermediate
            dst_start = shard * padded_intermediate
            padded_w13[:, dst_start : dst_start + intermediate, :] = w13[
                :, src_start : src_start + intermediate, :
            ]
        layer.w13_weight.data = padded_w13

        # Pad w2 along its intermediate (input / K) dimension.
        w2 = layer.w2_weight.data
        padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
        padded_w2[:, :, :intermediate] = w2
        layer.w2_weight.data = padded_w2

        if hasattr(layer, "intermediate_size_per_partition"):
            layer.intermediate_size_per_partition = padded_intermediate

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 MoE quant config (None for FlashInfer TRT-LLM)."""
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=layer.output2_scales_scalar.squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass with the selected backend.

        Backend dispatch:
        - FlashInfer TRT-LLM: routing is done inside the FlashInfer call.
        - FlashInfer CUTLASS: experts are selected first, then the FlashInfer
          CUTLASS FP8 MoE is called.
        - Otherwise: experts are selected first, then vLLM's ``fused_experts``
          is called.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
                )
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            assert not layer.renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=layer.activation,
                quant_config=self.moe_quant_config,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )


# Wire the ModelOpt FP8 config to its linear, fused-MoE and KV-cache methods.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """
        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint stores
                pre-quantized NVFP4 weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name, or None.
            exclude_modules: Module name patterns that are not quantized.
            group_size: Number of input elements per FP8 block scale.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Always set these, even for non-serialized checkpoints, so later
        # attribute access cannot fail with AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed checkpoint fields.

        Raises:
            ValueError: if an NVFP4 checkpoint with a ``quantization`` section
                lacks any of ``group_size``, ``kv_cache_quant_algo`` or
                ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, one value per output partition (reduced to a
        scalar max after loading).
    weight: NVFP4 with two FP4 values packed per uint8 byte,
        shape [out_features, in_features // 2].
    weight_scale: FP8-E4M3 per-block scale,
        shape [out_features, in_features // group_size].
    weight_scale_2: torch.float32 global weight scale, one value per output
        partition (reduced to a scalar max after loading).

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        """Pick the NVFP4 GEMM backend.

        Honors VLLM_NVFP4_GEMM_BACKEND when set. Otherwise it prefers
        FlashInfer CUTLASS, then vLLM CUTLASS, then Marlin.

        Raises:
            ValueError: if no usable backend is found.
        """
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed NVFP4 weight and its scale parameters.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or the input size is not a
                multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model: in_features size must be a multiple of 16, "
                f"got {input_size_per_partition}."
            )
        # Per-block weight scales are stored as FP8-E4M3 in serialized NVFP4
        # checkpoints (non-serialized checkpoints were rejected above).
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce global scales, precompute alpha / inverse input scale, and
        lay out the block scales (and weights) for the selected backend."""
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the NVFP4 linear layer with the selected backend.

        Marlin is dispatched directly. The other backends first quantize ``x``
        to FP4 plus block scales, then run a scaled FP4 GEMM. The output shape
        is ``[x.shape[0], out_features]``.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE method for ModelOpt NVFP4 checkpoints (uint8-packed fp4 weights with
    float8_e4m3fn block scales plus fp32 per-expert global scales).

    The kernel backend is chosen once at construction time from
    ``detect_nvfp4_moe_support``: FlashInfer (TRT-LLM / CUTLASS / CuteDSL),
    Marlin, or vLLM's own CUTLASS kernels.

    Args:
        quant_config: NVFP4 quant config; the checkpoint must be
            NVFP4-serialized (dynamic quantization is rejected).
        layer: The FusedMoE layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        layer: FusedMoE,
    ) -> None:
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.layer = layer
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        self.marlin_input_dtype = None
        # Stays None unless FlashInfer is allowed; the helper predicates
        # below rely on that.
        self.flashinfer_moe_backend = None
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for ModelOptNvFp4FusedMoE."
            )
        elif self.use_marlin:
            logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
        else:
            logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")

    def _uses_flashinfer_backend(self, *backends: FlashinferMoeBackend) -> bool:
        """Return True if FlashInfer is enabled with one of ``backends``."""
        return bool(self.allow_flashinfer) and self.flashinfer_moe_backend in backends

    def _uses_global_input_scales(self) -> bool:
        """Return True if the FlashInfer backend shares one activation
        (input) global scale across all experts.

        Used both when allocating the input-scale parameters and when
        post-processing them, so the two always agree.
        """
        return bool(self.allow_flashinfer) and bool(
            is_flashinfer_supporting_global_sf(self.flashinfer_moe_backend)
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Pick the prepare/finalize (dispatch/combine) stage for this backend.

        Marlin and FlashInfer TRT-LLM return None. FlashInfer CUTLASS uses the
        FlashInfer dispatcher when data parallel (dp_size > 1). Every other
        case defers to the base class.
        """
        if self.use_marlin or self._uses_flashinfer_backend(
            FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        if self._uses_flashinfer_backend(FlashinferMoeBackend.CUTLASS):
            # TP case: avoid convert to ModularKernelMethod - to be refactored.
            if self.moe.dp_size == 1:
                return None
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the NVFP4 experts implementation for the modular kernel.

        Requires ``self.moe_quant_config`` to be populated already.
        """
        assert self.moe_quant_config is not None
        experts = select_nvfp4_gemm_impl(
            self.moe,
            self.moe_quant_config,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed NVFP4 weight, scale, and input-scale params.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        group_size = self.quant_config.group_size
        # w13 stacks the w1 and w3 projections when the activation is gated
        # (act_and_mul); otherwise it holds a single projection.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # NOTE(review): for (E, out, in)-shaped tensors input_dim/output_dim
        # look swapped; presumably the FusedMoE weight_loader ignores them.
        # Confirm before relying on these attributes.

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # one fp8 block scale per `group_size` input elements
                hidden_size // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one fp8 block scale per `group_size` input elements
                intermediate_size_per_partition // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-expert fp32 global weight scales. w13 keeps one per shard here;
        # process_weights_after_loading collapses it to a single column.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # Backends with a shared (global) input scale allocate input scales
        # for global_num_experts rather than only the local experts.
        global_scale_num_experts = (
            global_num_experts if self._uses_global_input_scales() else num_experts
        )

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_scale_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded checkpoint tensors into the selected kernel's layout.

        Steps:
        - Keep only shard 0 of the w13 global weight scale.
        - Derive the fused GEMM alphas (input_scale * weight_scale_2) and the
          inverted activation quantization scales.
        - Apply the backend-specific re-layout: TRT-LLM shuffle, Marlin
          repack, or block-scale swizzle (with optional padding) for the
          CUTLASS-style kernels.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        # These FlashInfer backends take the gated halves in w3/w1 order.
        if self.moe.is_act_and_mul and self._uses_flashinfer_backend(
            FlashinferMoeBackend.CUTLASS, FlashinferMoeBackend.TENSORRT_LLM
        ):
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2: only shard 0's scale is
        # kept, so warn when the w1 and w3 scales disagree.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self._uses_global_input_scales()
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        if self._uses_flashinfer_backend(FlashinferMoeBackend.TENSORRT_LLM):
            # TensorRT-LLM specific processing:
            # prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.w13_weight = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
            layer.w13_weight_scale = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.w2_weight_scale = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )
        elif self.use_marlin:
            # Marlin processing; Marlin consumes the weight_scale_2 tensors
            # directly (see apply), so the derived alphas are dropped.
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            # Swizzling may pad the row (intermediate) dimension of the
            # scales; pad the weights and w2 scales to match.
            w13_weight = layer.w13_weight
            intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
            if intermediate_size_pad:
                # padding gated activations will require to split w1 and w3
                # and pad them individually
                assert not self.moe.is_act_and_mul, (
                    "The intermediate size required padding, "
                    "but padding is not implemented for gated activations"
                )

                layer.w13_weight = Parameter(
                    torch.nn.functional.pad(
                        w13_weight, (0, 0, 0, intermediate_size_pad)
                    ),
                    requires_grad=False,
                )
                # w2's input dim is fp4-packed: two values per uint8.
                layer.w2_weight = Parameter(
                    torch.nn.functional.pad(
                        layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
                    ),
                    requires_grad=False,
                )
                # One block scale per `group_size` inputs, consistent with
                # how create_weights sized this tensor.
                layer.w2_weight_scale = Parameter(
                    torch.nn.functional.pad(
                        layer.w2_weight_scale,
                        (0, intermediate_size_pad // self.quant_config.group_size),
                    ),
                    requires_grad=False,
                )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes ``hidden_states`` to fp4 with FlashInfer using the layer's
        w13 input quantization scale (non-swizzled scale layout). Returns the
        fp4 activations plus a list holding their scale factors.
        """
        import flashinfer

        a1_gscale = layer.w13_input_scale_quant
        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 quant config for the modular-kernel paths.

        Returns None for Marlin and FlashInfer TRT-LLM, which read their
        scales from the layer directly.
        """
        if self.use_marlin or self._uses_flashinfer_backend(
            FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass with the selected backend.

        ``x`` may also arrive as a tuple; only its first element is used for
        routing metadata. Dispatch order:
        1. FlashInfer TRT-LLM without EPLB (routing is passed to the kernel).
        2. FlashInfer TRT-LLM with EPLB (pre-routed).
        3. Marlin.
        4. FlashInfer CUTLASS / CuteDSL.
        5. vLLM CUTLASS.
        """
        if not self.moe.is_act_and_mul:
            assert self._uses_flashinfer_backend(FlashinferMoeBackend.CUTLASS), (
                "Non-gated activations are only supported by the"
                " flashinfer CUTLASS backend for modelopt checkpoints"
            )

        use_trtllm = self._uses_flashinfer_backend(FlashinferMoeBackend.TENSORRT_LLM)
        if use_trtllm and not layer.enable_eplb:
            # Routing inputs (logits, top-k, grouping, bias) go straight to the
            # kernel; select_experts is not called on this path.
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                custom_routing_function=layer.custom_routing_function,
                e_score_correction_bias=layer.e_score_correction_bias,
            )

        # Hidden_states in select_experts is only used to extract metadata
        if isinstance(x, tuple):
            x_routing, _ = x
        else:
            x_routing = x
        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x_routing,
            router_logits=router_logits,
        )

        # EPLB path
        if use_trtllm:
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
            )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
            )

        elif self.allow_flashinfer:
            assert self.flashinfer_moe_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

            assert self.moe_quant_config is not None
            return flashinfer_fn_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                # w2's input dim is fp4-packed (two values per byte).
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )


# Attach the NVFP4 linear, fused-MoE and KV-cache method implementations to
# the config class. The loop temporaries are deleted afterwards so they do not
# leak into the module namespace.
for _attr_name, _method_cls in (
    ("LinearMethodCls", ModelOptNvFp4LinearMethod),
    ("FusedMoEMethodCls", ModelOptNvFp4FusedMoE),
    ("KVCacheMethodCls", ModelOptFp8KVCacheMethod),
):
    setattr(ModelOptNvFp4Config, _attr_name, _method_cls)
del _attr_name, _method_cls