modelopt.py 59.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
18
    FusedMoEQuantConfig,
)
19
from vllm.model_executor.layers.fused_moe.layer import (
20
21
22
23
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
24
25
26
27
28
29
30
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    FLASHINFER_NVFP4_MOE_BACKENDS,
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
40
41
42
43
44
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
45
from vllm.model_executor.layers.quantization import QuantizationMethods
46
from vllm.model_executor.layers.quantization.base_config import (
47
48
49
    QuantizationConfig,
    QuantizeMethodBase,
)
50
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
51
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
52
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
53
    flashinfer_trtllm_fp4_moe,
54
    flashinfer_trtllm_fp4_routed_moe,
55
56
    select_nvfp4_gemm_impl,
)
57
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
58
    apply_fi_trtllm_fp8_per_tensor_moe,
59
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
60
61
    select_cutlass_fp8_gemm_impl,
)
62
63
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
64
65
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
66
)
67
68
69
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
70
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
71
72
73
74
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
)
75
from vllm.model_executor.layers.quantization.utils.quant_utils import (
76
77
78
79
80
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
81
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
82
    Fp8LinearOp,
83
    cutlass_block_fp8_supported,
84
85
    requantize_with_max_scale,
)
86
87
88
89
90
91
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
92
from vllm.model_executor.utils import replace_parameter
93
94
95
96
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
)
97

98
99
100
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

101
102
logger = init_logger(__name__)

103
104
105
106
107
108
109
110
111
112
# Weight-quantization algorithms accepted by
# ModelOptQuantConfigBase.from_config. The incoming `quant_algo` value is
# upper-cased before it is compared against this list, so entries here must be
# upper case.
QUANT_ALGOS = [
    # FP8 (per-tensor weight + optional static activation scale).
    "FP8",
    # FP8 per-channel weight scale + per-token activation scale.
    "FP8_PER_CHANNEL_PER_TOKEN",
    # FP8 per-block weight-only (ModelOpt may emit this as lowercase).
    "FP8_PB_WO",
    # FP4
    "NVFP4",
]
113
# NOTE(review): presumably the `kv_cache_quant_algo` values supported for
# ModelOpt checkpoints -- confirm against the code that consumes this list.
KV_CACHE_QUANT_ALGOS = ["FP8"]
114
115


116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    This subclass adds no behaviour of its own: the constructor only forwards
    the quantization config to ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase") -> None:
        """Forward ``quant_config`` unchanged to the base-class constructor."""
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ) -> None:
        """Initialise the base config and remember the exclusion patterns.

        Args:
            exclude_modules: Module-name patterns (ModelOpt wildcards) naming
                the layers that must not be quantized.
        """
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Decide whether the layer at ``prefix`` is excluded from quantization.

        Three checks run in order and the first hit wins: the
        ``is_layer_skipped`` helper (exact matching with fused-layer support),
        a legacy substring test, and ``fnmatch`` wildcard matching, because the
        ModelOpt ``exclude_modules`` list is a list of wildcards.

        Args:
            prefix: Dotted module path of the layer being checked.

        Returns:
            True if any entry of ``self.exclude_modules`` covers the layer,
            False otherwise (including when the list is empty).
        """
        patterns = self.exclude_modules
        if not patterns:
            return False

        # Exact matching, including fused layers via packed_modules_mapping.
        if is_layer_skipped(prefix, patterns, self.packed_modules_mapping):
            return True

        # TODO: The substring test below is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where exclusions are
        # handled naturally by the exclude_modules config. It is kept so that
        # checkpoints generated by older versions still load.
        has_lm_prefix = prefix.startswith("language_model.")
        bare_prefix = prefix.removeprefix("language_model.")

        def _substring_hit(pattern: str) -> bool:
            # Exact matches were already handled by is_layer_skipped above.
            if pattern == prefix:
                return False
            return pattern in prefix or (has_lm_prefix and pattern in bare_prefix)

        if any(_substring_hit(pattern) for pattern in patterns):
            return True

        # ModelOpt exclude modules are not simple strings, they are wildcards.
        return any(fnmatch(prefix, pattern) for pattern in patterns)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method object for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Dotted module path of ``layer``.

        Returns:
            * ``self.KVCacheMethodCls(self)`` for ``Attention`` layers.
            * For unquantized layers (excluded via ``is_layer_excluded`` or
              located under a vision encoder): ``UnquantizedLinearMethod()``
              when the layer is a ``LinearBase``, otherwise ``None``.
            * ``self.LinearMethodCls(self)`` for quantized ``LinearBase``
              layers and ``self.FusedMoEMethodCls(...)`` for ``FusedMoE``
              layers.
            * ``None`` for any other layer type.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # TODO: The hard-coded vision encoder check is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where these modules are
        # handled naturally by the exclude_modules config. It is kept so that
        # checkpoints generated by older versions still load.
        is_unquantized = (
            self.is_layer_excluded(prefix)
            or "vision_tower" in prefix
            or "vision_model" in prefix
        )
        if is_unquantized:
            # Only linear layers get the unquantized *linear* method; handing
            # it to any other layer type would give that layer a method it
            # cannot use, so those get None like every other unhandled layer.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
        else:
            return None

        # Methods that report a marlin backend also need the per-layer input
        # dtype chosen by get_marlin_input_dtype.
        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``self.exclude_modules`` with the given weights mapper.

        Does nothing when the exclusion list is empty. Otherwise every pattern
        of the form ``module_path*`` is first expanded (see the comment below)
        and the resulting list is passed through
        ``hf_to_vllm_mapper.apply_list``; the result replaces
        ``self.exclude_modules``.

        Args:
            hf_to_vllm_mapper: Mapper whose ``apply_list`` rewrites the names.
        """
        if len(self.exclude_modules) == 0:
            return

        # This is a workaround for the weights remapping issue:
        # https://github.com/vllm-project/vllm/issues/28072
        # Right now, the Nvidia ModelOpt library uses just one wildcard pattern:
        #        module_path*
        # It gets applied if the whole tree of modules rooted at module_path
        # is not quantized. Here we replace such a pattern by 2 patterns that
        # are collectively equivalent to the original pattern:
        #        module_path
        #        module_path.*
        expanded: list[str] = []
        for pattern in self.exclude_modules:
            covers_subtree = (
                len(pattern) >= 2
                and pattern.endswith("*")
                and not pattern.endswith(".*")
            )
            if covers_subtree:
                root = pattern[:-1]
                expanded += [root, root + ".*"]
            else:
                expanded.append(pattern)

        self.exclude_modules = hf_to_vllm_mapper.apply_list(expanded)
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the quantization config filename(s): ``hf_quant_config.json``."""
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Subclass hook that builds the concrete config object.

        ``from_config`` calls this with values it has already validated and
        normalized. The base implementation always raises.

        Args:
            quant_method: Upper-cased ``quant_algo`` value.
            kv_cache_quant_method: Upper-cased ``kv_cache_quant_algo`` value,
                or None when the config does not set one.
            exclude_modules: Patterns of modules to leave unquantized.
            original_config: The dictionary originally given to ``from_config``.
            group_size: Parsed ``group_size`` value, or None when absent.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Validate a ModelOpt quantization config and build the config object.

        Two layouts are accepted:

        * Traditional ModelOpt format (legacy ``hf_quant_config.json``):
          ``{"quantization": {"quant_algo": "...", "exclude_modules": [...]}}``
        * Compressed-tensors style format (``config.json``):
          ``{"quant_algo": "...", "quant_method": "modelopt", "ignore": [...]}``

        Args:
            config: Raw quantization config dictionary.

        Returns:
            The object produced by ``cls._from_config`` from the normalized
            values.

        Raises:
            ValueError: If ``quantization`` is not a dictionary, ``quant_algo``
                is missing or not listed in ``QUANT_ALGOS``,
                ``kv_cache_quant_algo`` is not a string, the exclusion list is
                not a list, or ``group_size`` cannot be converted to an int.
        """
        # Locate the dictionary holding the settings and the key under which
        # that layout stores the exclusion list.
        if "quantization" in config:
            section = cls.get_from_keys(config, ["quantization"])
            if not isinstance(section, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_key = "exclude_modules"
        else:
            section = config
            # "ignore" is the key in config.json
            exclude_key = "ignore"

        quant_method = section.get("quant_algo")
        kv_cache_quant_method = section.get("kv_cache_quant_algo")
        group_size_raw = section.get("group_size")
        exclude_modules = section.get(exclude_key, [])

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        # None means "no KV cache quantization" and is passed through as is.
        if kv_cache_quant_method is not None:
            if not isinstance(kv_cache_quant_method, str):
                raise ValueError(
                    "kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_method)}"
                )
            kv_cache_quant_method = kv_cache_quant_method.upper()

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None or isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            # Accept values that int() can convert, e.g. numeric strings.
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )

        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Quantization config for ModelOpt FP8 checkpoints."""

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the settings and select the linear method for ``quant_method``.

        Args:
            quant_method: One of ``FP8``, ``FP8_PER_CHANNEL_PER_TOKEN`` or
                ``FP8_PB_WO``.
            is_checkpoint_fp8_serialized: Whether the checkpoint stores FP8
                weights; a warning is logged when it does.
            kv_cache_quant_method: KV-cache quantization algorithm, or None.
            exclude_modules: Patterns of modules to leave unquantized.

        Raises:
            ValueError: If ``quant_method`` is not one of the three values
                listed above.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo.
        linear_methods = {
            "FP8": ModelOptFp8LinearMethod,
            "FP8_PER_CHANNEL_PER_TOKEN": ModelOptFp8PcPtLinearMethod,
            "FP8_PB_WO": ModelOptFp8PbWoLinearMethod,
        }
        linear_method_cls = linear_methods.get(self.quant_method)
        if linear_method_cls is None:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )
        self.LinearMethodCls = linear_method_cls

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name, ``"modelopt"``."""
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the supported activation dtypes: bfloat16 and half."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value, 89."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Args:
            hf_quant_cfg: Quantization config mapping, or None.
            user_quant: Not read by this implementation.

        Returns:
            ``"modelopt"`` when ``quant_method`` is ``modelopt`` (compared
            case-insensitively) and the ``quant_algo`` value contains ``FP8``;
            otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; only proceed if the
        # method is explicitly "modelopt".
        if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific structure: quant_algo lives in the nested dict.
            algo_source = hf_quant_cfg["quantization"]
            if not isinstance(algo_source, dict):
                return None
        else:
            # Compressed-tensors style config: quant_algo is a top-level key.
            algo_source = hf_quant_cfg

        quant_algo = str(algo_source.get("quant_algo", ""))
        return "modelopt" if "FP8" in quant_algo.upper() else None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config from values normalized by ``from_config``.

        ``original_config`` and any extra keyword arguments (such as
        ``group_size``) are accepted but not used. The checkpoint counts as
        FP8-serialized when ``quant_method`` contains ``FP8``.
        """
        return cls(
            quant_method,
            "FP8" in quant_method,
            kv_cache_quant_method,
            exclude_modules,
        )
421

422
423
424
425

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static, per-tensor activation quantization.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, the scale parameters.

        Args:
            layer: Module that receives the parameters and size attributes.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers
                held by this partition.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype used when the checkpoint is not
                FP8-serialized.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        del input_size, output_size

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if fp8_serialized:
            weight_dtype = torch.float8_e4m3fn
        else:
            weight_dtype = params_dtype
        weight_data = torch.empty(
            out_features, input_size_per_partition, dtype=weight_dtype
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_data,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # One scale slot per logical sub-layer, first for the weights and then
        # for the inputs. Both start at the float32 minimum, presumably as a
        # sentinel for "not loaded yet".
        for scale_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=loader,
            )
            scale_param[:] = torch.finfo(torch.float32).min
            layer.register_parameter(scale_name, scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the loaded scales to scalars and transpose the weight.

        When the per-sub-layer weight scales differ, the weight is requantized
        with ``requantize_with_max_scale``. ``input_scale`` becomes its
        maximum.
        """
        scales = layer.weight_scale
        if (scales == scales[0]).all():
            # Every sub-layer already shares one scale; keep the weight as is.
            weight = layer.weight
            max_w_scale = scales.max()
        else:
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, scales, layer.logical_widths
            )

        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on ``x`` with the layer's weight and scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
513
514


515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.

    Expected checkpoint structure (per Linear):
    - weight: fp8-e4m3fn, shape [out, in]
    - weight_scale: fp32, shape [out] (per-output-channel)
    - no input_scale (activations are dynamically quantized per-token)
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic (non-static) per-token activation quantization.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and per-channel ``weight_scale``.

        Args:
            layer: Module that receives the parameters and size attributes.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers
                held by this partition.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Not read; the weight is always float8_e4m3fn.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        weight_data = torch.empty(
            out_features,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_data,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # One float32 scale per output channel, initialised to the float32
        # minimum (presumably a sentinel for "not loaded yet").
        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(out_features, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Transpose the weight and re-wrap both tensors as plain Parameters."""
        transposed = layer.weight.t()
        layer.weight = Parameter(transposed, requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on ``x``; no input scale is supplied."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    # (block_n, block_k): block extent along the output / input dimension.
    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        weight_group = GroupShape(block_n, block_k)
        act_group = GroupShape(1, block_k)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=weight_group,
            act_quant_group_shape=act_group,
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and block-wise ``weight_scale``.

        Args:
            layer: Module that receives the parameters and size attributes.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers
                held by this partition.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Not read; the weight is always float8_e4m3fn.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized, or if the
                per-partition output/input size is not divisible by the block
                size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight_data = torch.empty(
            out_features,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_data,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        if out_features % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {out_features}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        scale_shape = (
            out_features // block_n,
            1,
            input_size_per_partition // block_k,
            1,
        )
        block_scale = BlockQuantScaleParameter(
            data=torch.empty(scale_shape, dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=loader,
        )
        block_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", block_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Re-wrap the weight and reduce ``weight_scale`` to a 2D tensor.

        Raises:
            ValueError: If ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        block_scale = layer.weight_scale
        rank = block_scale.dim()
        if rank not in (2, 4):
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(block_scale.shape)}."
            )
        if rank == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            block_scale = block_scale.squeeze(1).squeeze(-1)

        layer.weight_scale = Parameter(block_scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-wise FP8 linear op on ``x``; no input scale is supplied."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


712
713
714
715
716
717
718
719
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The FusedMoE layer this method is created for; its
            ``moe_config`` and ``moe_parallel_config`` are read at
            construction time.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        assert self.quant_config.is_checkpoint_fp8_serialized
        self.fp8_backend = select_fp8_moe_backend(
            block_quant=False,
            tp_size=layer.moe_parallel_config.tp_size,
            with_lora_support=self.moe.is_lora_enabled,
        )
        # Populated by _setup_kernel() after the weights have been processed.
        self.kernel: mk.FusedMoEModularKernel | None = None

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend, if any.

        Returns ``None`` for the FlashInfer TRTLLM backend and for the
        FlashInfer CUTLASS backend when ``dp_size == 1``; every other backend
        defers to the base-class implementation.
        """
        # TRT LLM not supported with all2all yet.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # TP case: avoid convert to ModularKernelMethod - to be refactored.
            if self.moe.dp_size == 1:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=False,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the CUTLASS FP8 experts implementation for this layer.

        Requires ``self.moe_quant_config`` to have been set (done in
        ``_setup_kernel``).
        """
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate and register expert weights and per-tensor scales on ``layer``.

        Registers ``w13_weight``, ``w2_weight``, their weight scales and their
        input scales. The ``weight_loader`` entry of ``extra_weight_attrs`` is
        attached to every parameter.
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stacks w1 and w3 along the output dimension of w13.
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpts
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, 2 if self.moe.is_act_and_mul else 1),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: torch.nn.Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ) -> None:
        """Convert weights to the backend's format and build the MoE kernel.

        The converted tensors replace the parameters on ``layer``.
        ``self.kernel`` and ``self.use_inplace`` are assigned only when a
        quant config is produced for the selected backend.
        """
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel for TP case.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded scales and weights, then set up the kernel."""
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the scales stored on ``layer``."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass for ``layer`` on ``x``.

        The FlashInfer TRTLLM backend is dispatched to its dedicated
        per-tensor entry point; every other backend selects experts via
        ``layer.select_experts`` and calls ``self.kernel``.

        Raises:
            NotImplementedError: If EPLB is enabled with the FlashInfer
                TRTLLM backend.
        """
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
                )
            # TODO(rob): this validation should happen at kernel selection
            # time in the oracle rather than here.
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )
            assert not layer.renormalize
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # The two message fragments are concatenated, so the first one
            # must end with a space.
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
992
993


994
995
996
997
998
999
# Attach the FP8 linear, fused-MoE and KV-cache method classes to the FP8
# config as class attributes.
# NOTE(review): presumably bound here, after the method classes are defined,
# because ModelOptFp8Config is declared earlier in the file — confirm.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the NVFP4 checkpoint settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint is already
                serialized in NVFP4.
            kv_cache_quant_algo: KV-cache quantization algorithm, if any.
            exclude_modules: Module names forwarded to the base config.
            group_size: Block size used for the per-block weight scales.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Always define these attributes so later reads cannot hit an
        # AttributeError when the checkpoint is not NVFP4-serialized.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name, ``"modelopt_fp4"``."""
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value for this config (75)."""
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``hf_quant_cfg`` declares
        ``quant_method == "modelopt"`` together with an FP4 ``quant_algo``;
        otherwise ``None``. Missing or null fields yield ``None``.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. ``or ""`` tolerates an
        # explicit null value instead of failing on ``None.lower()``.
        quant_method = (hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                # A null ``quant_algo`` is treated as absent.
                quant_algo = quant_config.get("quant_algo") or ""
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from already-parsed quantization fields.

        Raises:
            ValueError: If the checkpoint is NVFP4 and the nested
                ``quantization`` section lacks a required field.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )
1098
1099
1100
1101
1102


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure
    (X = output features, Y = input features of the partition):

    input_scale: torch.float32, one entry per output partition,
    weight: NVFP4 (two values packed per uint8 byte), Shape: [X, Y/2],
    weight_scale: FP8-E4M3, Shape: [X, Y/group_size], aka per block scale,
    weight_scale_2: torch.float32, one entry per output partition.

    ``input_scale`` and ``weight_scale_2`` are reduced to scalars (their max)
    in ``process_weights_after_loading``.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        # "none" is a sentinel meaning no usable GEMM backend was found.
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            # No explicit request: take the first available backend.
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 weight and scale parameters.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Per-block weight scales are stored as FP8-E4M3. The checkpoint is
        # known to be NVFP4-serialized here because of the check above.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input (activation) scale, one entry per output partition
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce the global scales and lay out weights for the chosen backend."""
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            # The marlin path of apply() reads neither of these.
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the NVFP4 linear layer output for ``x``.

        The marlin backend is dispatched directly; other backends quantize
        ``x`` to FP4 first and then run the FlashInfer or CUTLASS GEMM.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            # Strip the prefix to get the FlashInfer backend name.
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
1316
1317
1318
1319
1320


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1321
    Args:
1322
1323
1324
        quant_config: NVFP4 Quant Config
    """

1325
1326
1327
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        layer: FusedMoE,
    ) -> None:
        """Select the NVFP4 MoE backend for ``layer`` and record its traits.

        Raises:
            NotImplementedError: If the MoE is non-gated and the selected
                backend is not FlashInfer CUTLASS.
        """
        super().__init__(layer.moe_config)
        self.quant_config = quant_config

        backend = select_nvfp4_moe_backend()
        self.nvfp4_backend = backend

        # TODO: move this type of check into the oracle.
        if not (
            self.moe.is_act_and_mul
            or backend == NvFp4MoeBackend.FLASHINFER_CUTLASS
        ):
            raise NotImplementedError(
                "Non-gated activations are only supported by FlashInfer "
                "CUTLASS NvFP4 MoE backend."
            )

        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(backend)
        # Stays None until a modular kernel is assigned elsewhere in the class.
        self.kernel: mk.FusedMoEModularKernel | None = None
1347

1348
1349
1350
1351
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend, if any.

        Marlin and FlashInfer TRTLLM yield ``None``. FlashInfer CUTLASS yields
        ``None`` when ``dp_size == 1`` and its own prepare/finalize otherwise.
        Any other backend defers to the base-class implementation.
        """
        backend = self.nvfp4_backend

        if backend in (NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM):
            return None

        if not (backend == NvFp4MoeBackend.FLASHINFER_CUTLASS):
            return super().maybe_make_prepare_finalize(routing_tables)

        # TP case: avoid convert to ModularKernelMethod - to be refactored.
        if self.moe.dp_size == 1:
            return None

        # For now, fp4 moe only works with the flashinfer dispatcher.
        fi_prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
            self.moe
        )
        logger.debug_once("%s", fi_prepare_finalize.__class__.__name__)
        return fi_prepare_finalize
1367

1368
1369
1370
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the NVFP4 experts implementation for this method.

        FlashInfer is allowed only when the selected backend is one of
        ``FLASHINFER_NVFP4_MOE_BACKENDS``. Requires ``self.moe_quant_config``
        to be set.
        """
        assert self.moe_quant_config is not None
        flashinfer_ok = self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        gemm_impl = select_nvfp4_gemm_impl(
            self.moe,
            self.moe_quant_config,
            allow_flashinfer=flashinfer_ok,
        )
        logger.debug_once("Using %s", gemm_impl.__class__.__name__)
        return gemm_impl
1381

1382
1383
1384
1385
1386
1387
    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that this method uses the ``weight_scale_2`` naming pattern.

        FP4 variants name their per-tensor weight scales ``weight_scale_2``.

        Returns:
            Always ``True``.
        """
        return True

1388
1389
1390
1391
1392
1393
1394
1395
1396
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 expert weights and scales on ``layer``.

        Registers, in order: packed weights for both GEMMs, their per-block
        scales, their per-tensor ``*_weight_scale_2`` scales, and the input
        scales. ``weight_loader`` and ``global_num_experts`` are read from
        ``extra_weight_attrs``.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        packed_dtype = torch.uint8
        block_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")

        # Gated MoE stacks w1 and w3, doubling w13's row count and scale slots.
        w13_parts = 2 if self.moe.is_act_and_mul else 1
        w13_rows = w13_parts * intermediate_size_per_partition

        # GEMM 1 weight: two FP4 values share one byte along the hidden dim.
        layer.register_parameter(
            "w13_weight",
            ModelWeightParameter(
                data=torch.empty(
                    num_experts,
                    w13_rows,
                    hidden_size // 2,
                    dtype=packed_dtype,
                ),
                input_dim=1,
                output_dim=2,
                weight_loader=weight_loader,
            ),
        )

        # GEMM 2 weight: two FP4 values share one byte along the
        # intermediate dim.
        layer.register_parameter(
            "w2_weight",
            ModelWeightParameter(
                data=torch.empty(
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition // 2,
                    dtype=packed_dtype,
                ),
                input_dim=1,
                output_dim=2,
                weight_loader=weight_loader,
            ),
        )

        # GEMM 1 block scales: one per group of ``group_size`` hidden elements.
        layer.register_parameter(
            "w13_weight_scale",
            ModelWeightParameter(
                data=torch.empty(
                    num_experts,
                    w13_rows,
                    hidden_size // self.quant_config.group_size,
                    dtype=block_scale_dtype,
                ),
                input_dim=1,
                output_dim=2,
                weight_loader=weight_loader,
            ),
        )

        # GEMM 2 block scales: one per group of ``group_size`` intermediate
        # elements.
        layer.register_parameter(
            "w2_weight_scale",
            ModelWeightParameter(
                data=torch.empty(
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition // self.quant_config.group_size,
                    dtype=block_scale_dtype,
                ),
                input_dim=1,
                output_dim=2,
                weight_loader=weight_loader,
            ),
        )

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        # Per-tensor (global) weight scales.
        layer.register_parameter(
            "w13_weight_scale_2",
            PerTensorScaleParameter(
                data=torch.empty(num_experts, w13_parts, dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )
        layer.register_parameter(
            "w2_weight_scale_2",
            PerTensorScaleParameter(
                data=torch.empty(num_experts, dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Input scales are sized by the global expert count when the backend
        # supports global scale factors, else by the local expert count.
        if self.use_global_sf:
            input_scale_experts = global_num_experts
        else:
            input_scale_experts = num_experts

        layer.register_parameter(
            "w13_input_scale",
            PerTensorScaleParameter(
                data=torch.empty(
                    input_scale_experts,
                    w13_parts,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            ),
        )
        layer.register_parameter(
            "w2_input_scale",
            PerTensorScaleParameter(
                data=torch.empty(input_scale_experts, dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.

        The parameters loaded onto ``layer`` are passed to
        ``convert_to_nvfp4_moe_kernel_format`` for the selected backend, and
        the tensors it returns are re-registered on ``layer`` under the same
        names via ``replace_parameter``. The quant config is then built from
        the converted parameters. When a config is available and data
        parallelism is not in use (``self.moe.dp_size <= 1``), the kernel is
        created and stored on ``self.kernel``.

        Args:
            layer: MoE layer holding the ``w13_*`` / ``w2_*`` weight,
                weight-scale (``*_weight_scale``, ``*_weight_scale_2``) and
                input-scale parameters.
        """

        # Use a single gscale for w13.
        # Only column 0 of w13_weight_scale_2 is kept below; when the layer is
        # act-and-mul and the two columns differ, warn (once) instead of
        # failing.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        # Backend-specific conversion of weights and scales.
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Swap the converted tensors in under the original parameter names.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Must run after the replacements above: the quant config reads the
        # (now converted) scale parameters back off ``layer``.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        use_dp = self.moe.dp_size > 1
        # NOTE(review): no kernel is built here when dp_size > 1 — presumably
        # it is constructed elsewhere for the DP case; confirm.
        if self.moe_quant_config is not None and not use_dp:
            self.kernel = make_nvfp4_moe_kernel(
                backend=self.nvfp4_backend,
                quant_config=self.moe_quant_config,
                moe_config=self.moe,
            )
1560

1561
1562
1563
1564
1565
1566
1567
1568
1569
    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes ``hidden_states`` with ``flashinfer.fp4_quantize`` using the
        ``a1_gscale`` stored on ``self.moe_quant_config``.

        Args:
            layer: The MoE layer (not read by this implementation).
            hidden_states: Activations to quantize.
            router_logits: Not read by this implementation.

        Returns:
            A pair ``(hidden_states_fp4, extra_tensors)`` where
            ``extra_tensors`` is a one-element list holding the scale factors
            returned by ``fp4_quantize``.

        Raises:
            AssertionError: If ``self.moe_quant_config`` is ``None``.
        """
        # Imported lazily so the module does not require flashinfer at import
        # time.
        import flashinfer

        assert self.moe_quant_config is not None
        a1_gscale = self.moe_quant_config.a1_gscale
        # Scale factors are requested in the non-swizzled layout.
        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

1580
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the scales stored on ``layer``.

        Delegates to ``make_nvfp4_moe_quant_config`` for the selected backend,
        passing the block weight scales (``*_weight_scale``), the second-level
        weight scales (``*_weight_scale_2``) and the input scales for both the
        w13 and w2 projections.

        Args:
            layer: MoE layer whose scale parameters are read.

        Returns:
            Whatever ``make_nvfp4_moe_quant_config`` returns; the annotation
            allows ``None``.
        """
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

1593
1594
1595
1596
    @property
    def supports_eplb(self) -> bool:
        """Whether this MoE method supports EPLB; always ``True`` here."""
        return True

1597
1598
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE computation for ``layer``.

        Dispatch, in order:

        1. ``FLASHINFER_TRTLLM`` backend with ``layer.enable_eplb`` false:
           call ``flashinfer_trtllm_fp4_moe`` with the raw router logits and
           the layer's routing settings; ``layer.select_experts`` is not
           called on this path.
        2. ``FLASHINFER_TRTLLM`` backend with ``layer.enable_eplb`` true:
           select experts via ``layer.select_experts`` and call
           ``flashinfer_trtllm_fp4_routed_moe`` with the resulting ids and
           weights.
        3. Any other backend: select experts and invoke ``self.kernel``.

        Args:
            layer: The MoE layer providing weights and routing attributes.
            x: Input activations. NOTE(review): the ``isinstance`` check
                below indicates a 2-tuple may also be passed — presumably
                (quantized hidden states, scale factors); confirm against
                callers.
            router_logits: Router logits used for expert selection.

        Returns:
            The result of whichever kernel entry point was selected.

        Raises:
            AssertionError: If the non-TRTLLM path is taken while
                ``self.kernel`` is ``None``.
        """
        if (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        ):
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                custom_routing_function=layer.custom_routing_function,
                e_score_correction_bias=layer.e_score_correction_bias,
            )

        # Hidden_states in select_experts is only used to extract metadata
        if isinstance(x, tuple):
            x_routing, _ = x
        else:
            x_routing = x
        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x_routing,
            router_logits=router_logits,
        )

        # EPLB path
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            # The non-EPLB TRTLLM case already returned above.
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
            )
        else:
            assert self.kernel is not None
            return self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
1654
1655
1656
1657
1658


# Bind the concrete quantization-method classes onto the NVFP4 config class.
# NOTE(review): presumably done at module bottom because these classes are
# defined after ModelOptNvFp4Config in this module — confirm.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
# The KV-cache method bound for the NVFP4 config is the FP8 one.
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod