modelopt.py 59.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
    FusedMoEConfig,
18
19
    FusedMoEQuantConfig,
)
20
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
21
from vllm.model_executor.layers.fused_moe.layer import (
22
23
24
25
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
26
27
28
29
30
31
32
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
33
34
35
36
37
38
39
40
41
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    FLASHINFER_NVFP4_MOE_BACKENDS,
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
47
from vllm.model_executor.layers.quantization import QuantizationMethods
48
from vllm.model_executor.layers.quantization.base_config import (
49
50
51
    QuantizationConfig,
    QuantizeMethodBase,
)
52
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
53
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
54
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
55
    flashinfer_trtllm_fp4_moe,
56
    flashinfer_trtllm_fp4_routed_moe,
57
58
    select_nvfp4_gemm_impl,
)
59
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
60
    apply_fi_trtllm_fp8_per_tensor_moe,
61
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
62
63
    select_cutlass_fp8_gemm_impl,
)
64
65
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
66
67
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
68
)
69
70
71
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
72
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
73
74
75
76
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
)
77
from vllm.model_executor.layers.quantization.utils.quant_utils import (
78
79
80
81
82
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
83
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
84
    Fp8LinearOp,
85
    cutlass_block_fp8_supported,
86
87
    requantize_with_max_scale,
)
88
89
90
91
92
93
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
94
from vllm.model_executor.utils import replace_parameter
95
96
97
98
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
)
99

100
101
102
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

103
104
logger = init_logger(__name__)

# Upper-case names of the ModelOpt ``quant_algo`` values accepted by vLLM.
# ``ModelOptQuantConfigBase.from_config`` upper-cases the checkpoint's
# ``quant_algo`` before testing membership, so lower-case spellings emitted by
# ModelOpt are accepted as well.
QUANT_ALGOS = [
    # FP8 (per-tensor weight + optional static activation scale).
    "FP8",
    # FP8 per-channel weight scale + per-token activation scale.
    "FP8_PER_CHANNEL_PER_TOKEN",
    # FP8 per-block weight scales; the checkpoint carries no activation scale
    # (vLLM quantizes activations dynamically per token for this format).
    "FP8_PB_WO",
    # FP4
    "NVFP4",
]
# Supported ``kv_cache_quant_algo`` values (also compared upper-cased).
KV_CACHE_QUANT_ALGOS = ["FP8"]

class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantization method for ModelOpt FP8 checkpoints.

    All kv-cache scale handling is inherited from ``BaseKVCacheMethod``; this
    subclass adds no behavior of its own and only types the config it accepts.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Nothing ModelOpt-specific to set up: hand the config to the base.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Shared base for ModelOpt quantization configs.

    ``get_quant_method`` instantiates whatever classes are stored in
    ``LinearMethodCls``, ``FusedMoEMethodCls`` and ``KVCacheMethodCls``;
    subclasses select their concrete methods by setting these attributes and
    must implement ``_from_config``.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # Module names / wildcard patterns that must stay unquantized.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Three checks are tried in order; the first hit wins:

        1. exact matching with fused-layer support (``is_layer_skipped``);
        2. legacy substring matching of each entry against ``prefix``;
        3. ``fnmatch`` wildcard matching, since the ModelOpt exclude_modules
           list is a list of wildcards.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match.
        # NOTE(review): an earlier version also tested the entry against
        # ``prefix.removeprefix("language_model.")``. That string is a suffix of
        # ``prefix``, so it could never match anything the plain ``in prefix``
        # test misses; the clause was dropped without changing behavior. If the
        # intent was to strip "language_model." from the *entry* instead,
        # that would need a deliberate change.
        if any(
            exclude_module != prefix and exclude_module in prefix
            for exclude_module in self.exclude_modules
        ):
            return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(fnmatch(prefix, pattern) for pattern in self.exclude_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer`` (registered at ``prefix``).

        - ``Attention`` layers get ``KVCacheMethodCls``.
        - Excluded layers get ``UnquantizedLinearMethod`` if linear, else None.
        - Layers under a vision tower/model prefix get ``UnquantizedLinearMethod``.
        - Otherwise linear layers get ``LinearMethodCls``, ``FusedMoE`` layers
          get ``FusedMoEMethodCls``, and anything else gets None.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
        else:
            return None

        # Marlin-backed methods need to know the activation dtype for this layer.
        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` from HF names into vLLM module names."""
        if not self.exclude_modules:
            return

        # This is a workaround for the weights remapping issue:
        # https://github.com/vllm-project/vllm/issues/28072
        # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
        #        module_path*
        # It gets applied if the whole tree of modules rooted at module_path
        # is not quantized. Here we replace such pattern by 2 patterns:
        #        module_path
        #        module_path.*
        # which match module_path itself and every submodule under it. (Unlike
        # the original pattern they no longer match sibling names that merely
        # share the prefix, e.g. "module_path_2".)
        new_exclude_modules: list[str] = []
        for exclude in self.exclude_modules:
            if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
                new_exclude_modules.append(exclude[:-1])
                new_exclude_modules.append(exclude[:-1] + ".*")
            else:
                new_exclude_modules.append(exclude)

        self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from already-validated fields."""
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a ModelOpt quantization config.

        Two layouts are accepted:

        - legacy ``hf_quant_config.json``:
          ``{"quantization": {"quant_algo": ..., "kv_cache_quant_algo": ...,
          "group_size": ..., "exclude_modules": [...]}}``
        - compressed-tensors style ``config.json``:
          ``{"quant_method": "modelopt", "quant_algo": ...,
          "kv_cache_quant_algo": ..., "group_size": ..., "ignore": [...]}``

        ``quant_algo`` and ``kv_cache_quant_algo`` are upper-cased, and an
        explicit ``null`` exclude list is treated as "nothing excluded".

        Raises:
            ValueError: if ``quant_algo`` is missing or unsupported, or a field
                has the wrong type.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        # kv_cache_quant_algo is optional; None means no KV cache quantization.
        if kv_cache_quant_method is not None:
            if not isinstance(kv_cache_quant_method, str):
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_method)}"
                )
            kv_cache_quant_method = kv_cache_quant_method.upper()

        # A present-but-null exclude list means nothing is excluded.
        if exclude_modules is None:
            exclude_modules = []
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        # group_size is optional; accept ints or anything int() can convert.
        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8.

    The linear method is chosen from ``quant_method``:
    ``FP8`` -> ``ModelOptFp8LinearMethod``,
    ``FP8_PER_CHANNEL_PER_TOKEN`` -> ``ModelOptFp8PcPtLinearMethod``,
    ``FP8_PB_WO`` -> ``ModelOptFp8PbWoLinearMethod``.
    """

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """
        Raises:
            ValueError: if ``quant_method`` is not one of the FP8 variants above.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo.
        if self.quant_method == "FP8":
            self.LinearMethodCls = ModelOptFp8LinearMethod
        elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
            self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
        elif self.quant_method == "FP8_PB_WO":
            self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
        else:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` (nested under
        ``"quantization"`` or top-level) contains ``"FP8"``; otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. An explicit null (or a
        # non-string value) is treated as "not modelopt" instead of crashing.
        quant_method = str(hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = str(quant_config.get("quant_algo", ""))
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = str(hf_quant_cfg.get("quant_algo", ""))

        if "FP8" in quant_algo.upper():
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        # ``kwargs`` absorbs fields this config does not use (e.g. group_size).
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static (checkpoint-provided) per-tensor activation scale.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` (and, for FP8-serialized checkpoints, one
        ``weight_scale`` and one ``input_scale`` entry per output partition)
        on ``layer``.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE: one per logical output partition (e.g. q/k/v shards).
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE: also one per partition; reduced to the max after loading.
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-partition scales to single per-tensor scales.

        If the partitions' weight scales differ, the weight is requantized to
        the max scale via ``requantize_with_max_scale``. The weight is stored
        transposed and ``input_scale`` becomes the max over partitions.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM with the static weight and input scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )

class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8_PER_CHANNEL_PER_TOKEN`` checkpoints.

    Per-Linear tensors expected in the checkpoint:

    - ``weight``: float8_e4m3fn, shape ``[out, in]``;
    - ``weight_scale``: float32, shape ``[out]`` (one scale per output channel);
    - no ``input_scale``: activations are quantized dynamically, per token.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # No static activation scale: quantize activations on the fly with
        # one scale per token.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and per-channel ``weight_scale``.

        Raises:
            ValueError: if the config does not mark the checkpoint as
                FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        loader = extra_weight_attrs.get("weight_loader")
        total_rows = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_rows

        fp8_weight = ModelWeightParameter(
            data=torch.empty(
                (total_rows, input_size_per_partition), dtype=torch.float8_e4m3fn
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", fp8_weight)

        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(total_rows, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        # Pre-filled with float32 min (presumably a sentinel for entries the
        # loader never writes).
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Store the weight transposed and freeze both loaded tensors."""
        scales = layer.weight_scale.data
        transposed = layer.weight.t()
        layer.weight = Parameter(transposed, requires_grad=False)
        layer.weight_scale = Parameter(scales, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM; no static input scale is supplied."""
        result = self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )
        return result


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    # (block_n, block_k): rows x columns of weight covered by one scale.
    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and block-wise ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized, or if the
                output features (in total or per logical partition) or the
                input features are not multiples of the block size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )
        # For merged layers (several logical partitions, e.g. fused gate/up),
        # each partition's scales are loaded into block-space offsets, so every
        # partition must itself start and end on a block boundary; a divisible
        # total alone (e.g. [64, 64]) would misalign the scales.
        if any(size % block_n != 0 for size in output_partition_sizes):
            raise ValueError(
                "ModelOpt FP8_PB_WO requires every output partition to be "
                f"divisible by {block_n}, got {output_partition_sizes}."
            )

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Freeze the weight and reshape the scale to ``[out_blk, in_blk]``.

        Raises:
            ValueError: if the loaded ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-scaled FP8 GEMM (activations quantized dynamically)."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )

class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.
    Args:
        quant_config: The ModelOpt quantization config.
    """

724
725
726
    def __init__(
        self,
        quant_config: ModelOptFp8Config,
727
        moe_config: FusedMoEConfig,
728
    ) -> None:
729
        super().__init__(moe_config)
730
        self.quant_config = quant_config
731
732
733
        assert self.quant_config.is_checkpoint_fp8_serialized
        self.fp8_backend = select_fp8_moe_backend(
            block_quant=False,
734
            tp_size=moe_config.moe_parallel_config.tp_size,
735
            with_lora_support=self.moe.is_lora_enabled,
736
        )
737
        self.kernel: mk.FusedMoEModularKernel | None = None
738
739

    def maybe_make_prepare_finalize(
740
        self,
741
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
742
    ) -> mk.FusedMoEPrepareAndFinalize | None:
743
        # TRT LLM not supported with all2all yet.
744
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
745
            return None
746
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
747
748
749
750
751
752
753
754
755
756
            # TP case: avoid convert to ModularKernelMethod - to be refactored.
            if self.moe.dp_size == 1:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=False,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
757
        return super().maybe_make_prepare_finalize(routing_tables)
758
759
760
761

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
762
        layer: torch.nn.Module,
763
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
764
        assert self.moe_quant_config is not None
765
        experts = select_cutlass_fp8_gemm_impl(
766
767
            self.moe,
            self.moe_quant_config,
768
769
770
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts
771
772
773
774
775
776
777
778
779
780

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
781
782
783
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

784
        # Use FP8 dtype if checkpoint is serialized
785
786
787
788
789
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
790
791
        weight_loader = extra_weight_attrs.get("weight_loader")

792
793
794
795
796
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

797
        w13_weight = ModelWeightParameter(
798
799
            data=torch.empty(
                num_experts,
800
                w13_up_dim,
801
802
803
                hidden_size,
                dtype=weight_dtype,
            ),
804
805
806
807
808
809
810
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
811
812
813
814
815
816
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
817
818
819
820
821
822
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
        # WEIGHT SCALES - Per-tensor scaling for ModelOpts
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, 2 if self.moe.is_act_and_mul else 1),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
841

842
843
844
845
        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
846
        )
847
848
849
850
851
852
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)
853

854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
    def _setup_kernel(
        self,
        layer: torch.nn.Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )
874

875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel for TP case.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
890
            )
891

892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
916
917
        )

918
919
920
921
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )
922

923
    def get_fused_moe_quant_config(
924
        self, layer: torch.nn.Module
925
    ) -> FusedMoEQuantConfig | None:
926
927
928
929
930
931
932
933
934
935
936
937
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )
938

939
940
    def apply(
        self,
941
        layer: FusedMoE,
942
        router: FusedMoERouter,
943
944
        x: torch.Tensor,
        router_logits: torch.Tensor,
945
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
946
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
947
948
            if layer.enable_eplb:
                raise NotImplementedError(
949
                    "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
950
                )
951
952
            # TODO(rob): this validation should happen at kernel selection
            # time in the oracle rather than here.
953
954
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
955
            )
956
            assert not layer.renormalize
957
            return apply_fi_trtllm_fp8_per_tensor_moe(
958
959
960
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
961
962
963
964
965
966
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
967
            )
968

969
970
        # Expert selection
        topk_weights, topk_ids = router.select_experts(
971
972
973
            hidden_states=x,
            router_logits=router_logits,
        )
974

975
976
977
        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
978
            assert layer.activation in ("silu", "relu2_no_mul"), (
979
                "Expected activation to be in ('silu', 'relu2_no_mul'),"
980
                f"but got {layer.activation}"
981
            )
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
998
999


1000
1001
1002
1003
1004
1005
# Attach the ModelOpt FP8 method implementations to ModelOptFp8Config as
# class attributes. How the config consumes these attributes is not visible
# here; presumably the base config instantiates them when building per-layer
# quant methods (TODO confirm against ModelOptQuantConfigBase).
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4 (NVFP4).

    Args:
        is_checkpoint_nvfp4_serialized: Whether the checkpoint already holds
            NVFP4-quantized weights (``_from_config`` derives this from
            ``"NVFP4" in quant_method``).
        kv_cache_quant_algo: KV-cache quantization algorithm name, or None.
        exclude_modules: Forwarded unchanged to the base config.
        group_size: Number of input elements sharing one per-block weight
            scale (default 16).
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )
        # Always set these, not only for serialized checkpoints, so the
        # config object is fully initialized regardless of the flag.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the registered quantization method name."""
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum supported GPU compute capability."""
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and either the nested ``quantization.quant_algo``
        contains ``"NVFP4"`` or, without a ``quantization`` section, the
        top-level ``quant_algo`` contains ``"FP4"`` (case-insensitive).
        Returns None otherwise, including when fields are missing or null.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. Treat a null or
        # non-string value as "not modelopt" instead of crashing on .lower().
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                # A null quant_algo would otherwise raise TypeError on `in`.
                quant_algo = quant_config.get("quant_algo") or ""
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed checkpoint quantization fields.

        Raises:
            ValueError: For an NVFP4 checkpoint whose ``quantization`` section
                lacks any of ``group_size``, ``kv_cache_quant_algo`` or
                ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters, as
    allocated in ``create_weights`` (N = output_size_per_partition,
    K = input_size_per_partition):

    - input_scale: torch.float32, one value per output partition; reduced
      to a scalar (max) after loading.
    - weight: NVFP4 packed two values per byte (torch.uint8), shape [N, K // 2].
    - weight_scale: FP8-E4M3 per-block scale, shape [N, K // group_size].
    - weight_scale_2: torch.float32, one value per output partition; reduced
      to a scalar (max) after loading.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        # Pick the GEMM backend: auto-detect when VLLM_NVFP4_GEMM_BACKEND is
        # unset, otherwise honor the explicit "flashinfer-*" or "cutlass"
        # choice. Any other value leaves the backend as "none" and fails below.
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once("Using %s for NVFP4 GEMM", self.backend)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 weight and scale parameters on ``layer``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or ``input_size_per_partition``
                is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # dtype of the per-block weight scales: FP8-E4M3 for serialized
        # NVFP4 checkpoints (the only case that reaches this point).
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce global scales and lay out weights for the chosen backend."""
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the NVFP4 linear layer output for ``x``.

        All leading dimensions of ``x`` are preserved: the non-Marlin paths
        run the FP4 GEMM on a 2-D [tokens, in_features] view and reshape the
        result back to ``[*x.shape[:-1], out_features]``.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        output_dtype = x.dtype
        # Previously only x.shape[0] was kept, so inputs with more than two
        # dims failed at the final .view(); keep every leading dim instead.
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
        x_2d = x.reshape(-1, x.shape[-1])

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x_2d, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Args:
        quant_config: NVFP4 Quant Config
        moe_config: Fused-MoE layer configuration, forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        """
        Select the NVFP4 MoE backend and record whether it uses global scales.

        Raises:
            NotImplementedError: if the MoE activation is not gated
                (`is_act_and_mul` is False) and the selected backend is not
                FlashInfer CUTLASS.
        """
        super().__init__(moe_config)
        self.quant_config = quant_config
        self.nvfp4_backend = select_nvfp4_moe_backend()
        # TODO: move this type of check into the oracle.
        if (
            not self.moe.is_act_and_mul
            and self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_CUTLASS
        ):
            raise NotImplementedError(
                "Non-gated activations are only supported by FlashInfer "
                "CUTLASS NvFP4 MoE backend."
            )

        # When True, the input scales are sized by the global expert count
        # (see create_weights) instead of the local one.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )
        # Built lazily in process_weights_after_loading (non-DP case only).
        self.kernel: mk.FusedMoEModularKernel | None = None

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """
        Return the prepare/finalize object for this backend, if any.

        - Marlin and FlashInfer TRT-LLM backends: always None.
        - FlashInfer CUTLASS: None when dp_size == 1, otherwise the FlashInfer
          FP4 CUTLASS prepare/finalize.
        - Any other backend: defer to the base-class implementation.
        """
        unsupported = (NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM)
        if self.nvfp4_backend in unsupported:
            return None
        elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
            # TP case: avoid convert to ModularKernelMethod - to be refactored.
            if self.moe.dp_size == 1:
                return None
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """
        Pick the experts (GEMM) implementation for the modular kernel.

        FlashInfer implementations are only allowed when the selected backend
        is one of FLASHINFER_NVFP4_MOE_BACKENDS. Requires moe_quant_config to
        have been set (see process_weights_after_loading).
        """
        assert self.moe_quant_config is not None
        experts = select_nvfp4_gemm_impl(
            self.moe,
            self.moe_quant_config,
            allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Register the packed NVFP4 expert weights and their scales on `layer`.

        Notation: E = num_experts, H = hidden_size,
        I = intermediate_size_per_partition, G = quant_config.group_size,
        M = 2 if the activation is gated (is_act_and_mul) else 1.

        Registered parameters:
          - w13_weight          uint8        (E, M * I, H // 2)
          - w2_weight           uint8        (E, H, I // 2)
          - w13_weight_scale    float8_e4m3  (E, M * I, H // G)
          - w2_weight_scale     float8_e4m3  (E, H, I // G)
          - w13_weight_scale_2  float32      (E, M)
          - w2_weight_scale_2   float32      (E,)
          - w13_input_scale     float32      (E', M)
          - w2_input_scale      float32      (E',)
        E' is extra_weight_attrs["global_num_experts"] when the backend
        supports global scale factors (self.use_global_sf), otherwise E.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                # one FP8 block scale per `group_size` input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one FP8 block scale per `group_size` input elements
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): extra_weight_attrs is a local **kwargs dict and is not
        # read after this point in this method, so this update (and the one
        # below) appears to have no effect here. Confirm before relying on it.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(
                num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Input scales cover all global experts when the backend uses global
        # scale factors; otherwise only the local experts.
        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                2 if self.moe.is_act_and_mul else 1,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.

        Only column 0 of w13_weight_scale_2 is kept (a single per-tensor scale
        for the fused w1/w3 weight). A warning is logged when the two columns
        differ for gated activations. The modular kernel is only built when
        dp_size == 1.
        """
        # Use a single gscale for w13.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        use_dp = self.moe.dp_size > 1
        if self.moe_quant_config is not None and not use_dp:
            self.kernel = make_nvfp4_moe_kernel(
                backend=self.nvfp4_backend,
                quant_config=self.moe_quant_config,
                moe_config=self.moe,
            )

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes `hidden_states` with flashinfer.fp4_quantize using the a1
        global scale, in the non-swizzled scale-factor layout, and returns
        (quantized hidden states, [scale factors]). `layer` and
        `router_logits` are not used here.
        """
        import flashinfer

        assert self.moe_quant_config is not None
        a1_gscale = self.moe_quant_config.a1_gscale
        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's registered scales."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Run the NVFP4 MoE forward pass.

        Three paths:
          1. FlashInfer TRT-LLM without EPLB: routing and experts are handled
             by flashinfer_trtllm_fp4_moe directly from `router_logits`.
          2. FlashInfer TRT-LLM with EPLB: experts are selected with `router`,
             then flashinfer_trtllm_fp4_routed_moe is called.
          3. Every other backend: experts are selected with `router`, then
             the modular kernel built in process_weights_after_loading runs.
        """
        if (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        ):
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                custom_routing_function=layer.custom_routing_function,
                e_score_correction_bias=layer.e_score_correction_bias,
            )

        # Hidden_states in select_experts is only used to extract metadata.
        # NOTE(review): `x` may arrive as a tuple, presumably
        # (hidden_states, extra) from a DP all-gather. Confirm with callers.
        if isinstance(x, tuple):
            x_routing, _ = x
        else:
            x_routing = x
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x_routing,
            router_logits=router_logits,
        )

        # EPLB path
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
            )
        else:
            assert self.kernel is not None
            return self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )


# Attach the method classes ModelOptNvFp4Config hands out per layer type:
# KV cache, fused MoE, and linear layers (the assignments are independent).
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod