# Source file: modelopt.py (71.2 KB), captured from a web code viewer.
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any
6
7
8
9

import torch
from torch.nn.parameter import Parameter

10
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
11
from vllm.logger import init_logger
12
13
14
from vllm.model_executor.kernels.linear import (
    init_fp8_linear_kernel,
)
15
from vllm.model_executor.layers.attention import Attention, MLAAttention
16
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
17
from vllm.model_executor.layers.fused_moe.config import (
18
    FusedMoEConfig,
19
20
    FusedMoEQuantConfig,
)
21
from vllm.model_executor.layers.fused_moe.layer import (
22
23
24
25
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
26
27
28
29
30
31
32
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
33
34
35
36
37
38
39
40
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
46
from vllm.model_executor.layers.quantization import QuantizationMethods
47
from vllm.model_executor.layers.quantization.base_config import (
48
49
50
    QuantizationConfig,
    QuantizeMethodBase,
)
51
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
52
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
53
    flashinfer_trtllm_fp4_moe,
54
    flashinfer_trtllm_fp4_routed_moe,
55
)
56
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
57
    apply_fi_trtllm_fp8_per_tensor_moe,
58
)
59
60
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
61
62
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
63
)
64
65
66
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
67
68
69
70
71
72
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    MXFP8_BLOCK_SIZE,
    MXFP8_SCALE_DTYPE,
    MXFP8_VALUE_DTYPE,
    Mxfp8LinearBackend,
    Mxfp8LinearOp,
73
    swizzle_mxfp8_scale,
74
)
75
76
77
78
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
    apply_nvfp4_linear,
    convert_to_nvfp4_linear_kernel_format,
    select_nvfp4_linear_backend,
79
)
80
from vllm.model_executor.layers.quantization.utils.quant_utils import (
81
82
    GroupShape,
    is_layer_skipped,
83
84
85
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
86
87
    kNvfp4Dynamic,
    kNvfp4Static,
88
)
89
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
90
    cutlass_block_fp8_supported,
91
92
    requantize_with_max_scale,
)
93
94
95
96
97
98
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
99
from vllm.model_executor.utils import replace_parameter
100

101
102
103
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

104
105
logger = init_logger(__name__)

106
107
108
109
110
111
112
113
114
# Weight quantization algorithms (upper-cased `quant_algo` values) that the
# ModelOpt integration accepts. `from_config` rejects anything not listed here.
QUANT_ALGOS = [
    # FP8 (per-tensor weight + optional static activation scale).
    "FP8",
    # FP8 per-channel weight scale + per-token activation scale.
    "FP8_PER_CHANNEL_PER_TOKEN",
    # FP8 per-block weight-only (ModelOpt may emit this as lowercase).
    "FP8_PB_WO",
    # FP4
    "NVFP4",
    # MXFP8
    "MXFP8",
    # Mixed precision
    "MIXED_PRECISION",
]
# KV-cache quantization algorithms.
KV_CACHE_QUANT_ALGOS = ["FP8"]
121
122


123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method used for ModelOpt FP8 checkpoints.

    Supports loading kv-cache scaling factors from FP8 checkpoints. All
    behavior comes from ``BaseKVCacheMethod``; this subclass only narrows the
    accepted config type.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Nothing ModelOpt-specific to set up: defer entirely to the base class.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base class for the ModelOpt quantization configs.

    Provides exclusion matching (``is_layer_excluded``), per-layer dispatch to
    a quantization method (``get_quant_method``) and parsing of the
    checkpoint's quantization config (``from_config``). Subclasses may
    override the ``*MethodCls`` attributes and must implement ``_from_config``.
    """

    # Classes instantiated by get_quant_method(); subclasses may override them.
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        """Store the module-name patterns that must stay unquantized.

        Args:
            exclude_modules: Module names / wildcard patterns to exclude.
        """
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and ModelOpt wildcard matching.

        The ModelOpt exclude_modules list is a list of wildcards.

        Args:
            prefix: Name of the layer being checked.

        Returns:
            True if the exact, substring or wildcard check matches ``prefix``.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            # NOTE(review): the "language_model." clause looks implied by the
            # plain substring test (a substring of a suffix is a substring of
            # the whole); kept as-is to preserve the original logic.
            if exclude_module != prefix and (
                exclude_module in prefix
                or (
                    prefix.startswith("language_model.")
                    and exclude_module in prefix.removeprefix("language_model.")
                )
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(
            fnmatch(prefix, wildcard_pattern)
            for wildcard_pattern in self.exclude_modules
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module to be handled.
            prefix: Name of the layer, used for exclusion matching.

        Returns:
            A KV-cache method for attention layers, an unquantized linear
            method for excluded linear layers and vision-tower/model prefixes,
            the configured linear / fused-MoE method for quantized layers, or
            None when the layer is not handled here.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, (Attention, MLAAttention)):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        # NOTE(review): unlike the exclusion branch above, this returns an
        # unquantized *linear* method regardless of the layer type — confirm
        # that is intended for non-linear vision layers.
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` through the HF-to-vLLM name mapper.

        Args:
            hf_to_vllm_mapper: Mapper whose ``apply_list`` is applied to the
                (expanded) exclusion patterns.
        """
        if self.exclude_modules:
            # This is a workaround for the weights remapping issue:
            # https://github.com/vllm-project/vllm/issues/28072
            # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
            #        module_path*
            # It gets applied if the whole tree of modules rooted at module_path
            # is not quantized. Here we replace such pattern by 2 patterns that are
            # collectively equivalent to the original pattern:
            #        module_path
            #        module_path.*
            new_exclude_modules = []
            for exclude in self.exclude_modules:
                if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
                    stem = exclude[:-1]
                    new_exclude_modules.append(stem)
                    new_exclude_modules.append(stem + ".*")
                else:
                    new_exclude_modules.append(exclude)

            self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def _extract_modelopt_quant_algo(
        hf_quant_cfg: dict[str, Any] | None,
    ) -> str | None:
        """Extract upper-cased quant_algo from a modelopt config.

        Returns the quant_algo string (upper-cased), or None if the config
        is not a modelopt config.
        """
        if hf_quant_cfg is None:
            return None
        # `or ""` guards against an explicit null "quant_method", which would
        # otherwise raise AttributeError on `.lower()`.
        if str(hf_quant_cfg.get("quant_method") or "").lower() != "modelopt":
            return None
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                return str(quant_config.get("quant_algo", "")).upper()
            return None
        return str(hf_quant_cfg.get("quant_algo", "")).upper()

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the file names searched for the quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from values validated by ``from_config``.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a quantization config dict.

        Args:
            config: Either the traditional ModelOpt layout (values nested
                under ``"quantization"``) or the flat compressed-tensors style
                layout.

        Returns:
            The config built by ``cls._from_config``.

        Raises:
            ValueError: If ``quant_algo`` is missing or unsupported, or if
                ``quantization``, ``kv_cache_quant_algo``, ``exclude_modules``
                or ``group_size`` has an invalid type.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")

            # Type validation of the values below happens after this branch.
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format (config.json quantization_config):
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")

            # "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
            kv_cache_scheme = config.get("kv_cache_scheme")
            if isinstance(kv_cache_scheme, dict) and (
                kv_cache_scheme.get("type") == "float"
                and kv_cache_scheme.get("num_bits") == 8
            ):
                kv_cache_quant_method = "FP8"
            else:
                kv_cache_quant_method = None

            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        # None means no KV cache quantization and is passed through unchanged.
        if kv_cache_quant_method is not None:
            if not isinstance(kv_cache_quant_method, str):
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_method)}"
                )
            kv_cache_quant_method = kv_cache_quant_method.upper()

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            # Accept values convertible to int (e.g. numeric strings).
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the FP8 settings and select the matching linear method.

        Args:
            quant_method: ModelOpt ``quant_algo``; one of ``FP8``,
                ``FP8_PER_CHANNEL_PER_TOKEN`` or ``FP8_PB_WO``.
            is_checkpoint_fp8_serialized: Whether the checkpoint stores FP8
                weights.
            kv_cache_quant_method: KV-cache quantization algorithm, or None.
            exclude_modules: Module patterns to leave unquantized.

        Raises:
            ValueError: If ``quant_method`` is not a supported FP8 algorithm.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo.
        if self.quant_method == "FP8":
            self.LinearMethodCls = ModelOptFp8LinearMethod
        elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
            self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
        elif self.quant_method == "FP8_PB_WO":
            self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
        else:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name, ``"modelopt"``."""
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the supported activation dtypes (bfloat16 and half)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value, 89."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Return ``"modelopt"`` when the config is ModelOpt with algo ``FP8``.

        Args:
            hf_quant_cfg: Quantization config dict, or None.
            user_quant: Not used by this implementation.

        Returns:
            ``"modelopt"`` for a ModelOpt ``FP8`` config, otherwise None.
        """
        # A None result (non-ModelOpt config) simply fails the comparison.
        if cls._extract_modelopt_quant_algo(hf_quant_cfg) == "FP8":
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config from values validated by ``from_config``.

        ``original_config`` and any extra keyword arguments (such as
        ``group_size``) are accepted but not used here.
        """
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
437

438
439
440
441

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static per-tensor quantization keys for both activations and weights.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8StaticTensorSym,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, the scale parameters.

        Args:
            layer: Module on which the parameters are registered.
            input_size_per_partition: Input dimension of this partition.
            output_partition_sizes: Output widths of the logical sub-weights;
                their sum is the partition's output dimension.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype when the checkpoint is not
                FP8-serialized.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # One scale per logical sub-weight, pre-filled with float32 min so
            # an entry that is never loaded cannot win the `.max()` taken in
            # process_weights_after_loading.
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse the per-sub-weight scales and transpose the weight.

        Requantizes with the maximum scale only when the loaded weight scales
        differ, then replaces ``weight`` (transposed), ``weight_scale`` and
        ``input_scale`` with plain non-trainable Parameters.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        # Requantization is only needed when the logical shards disagree.
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer through the FP8 kernel chosen in ``__init__``."""
        return self.fp8_linear.apply_weights(layer, x, bias)
526
527


528
529
530
531
532
533
534
535
536
537
538
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.

    Expected checkpoint structure (per Linear):
    - weight: fp8-e4m3fn, shape [out, in]
    - weight_scale: fp32, shape [out] (per-output-channel)
    - no input_scale (activations are dynamically quantized per-token)
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic per-token activation key, static per-token weight key.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8StaticTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and its per-output-channel scale.

        Args:
            layer: Module on which the parameters are registered.
            input_size_per_partition: Input dimension of this partition.
            output_partition_sizes: Output widths of the logical sub-weights;
                their sum is the partition's output dimension.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Not read; the weight is always float8_e4m3fn.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One fp32 scale per output channel, pre-filled with float32 min.
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty(output_size_per_partition, dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transpose the weight and re-wrap both tensors as plain Parameters."""
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer through the FP8 kernel chosen in ``__init__``."""
        return self.fp8_linear.apply_weights(layer, x, bias)
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    # (block_n, block_k): weight block size along the output / input dims.
    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and its block-wise scale.

        Args:
            layer: Module on which the parameters are registered.
            input_size_per_partition: Input dimension of this partition; must
                be divisible by the block size.
            output_partition_sizes: Output widths of the logical sub-weights;
                their sum must be divisible by the block size.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Not read; the weight is always float8_e4m3fn.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized, or a
                partition dimension is not divisible by the block size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the weight and flatten the scale to 2D.

        Raises:
            ValueError: If ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-wise FP8 linear op; no static input scale is passed."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


722
723
724
725
726
727
728
729
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        moe_config: The fused-MoE configuration, forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        assert self.quant_config.is_checkpoint_fp8_serialized

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Legacy hook that this method does not support; always raises.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Legacy hook that this method does not support; always raises.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register expert weights and per-tensor scales.

        Registers ``w13_weight``, ``w2_weight``, ``w13_weight_scale``,
        ``w2_weight_scale``, ``w13_input_scale`` and ``w2_input_scale`` on
        ``layer``. All scales are initialised to 1.0.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Dtype used when the checkpoint is not FP8
                serialized; also recorded as ``layer.orig_dtype``.
            **extra_weight_attrs: Only ``weight_loader`` is read here.
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE fuses w1 and w3 along the output dimension of w13.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, w13_num_shards),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend format and build the modular kernel.

        The converted weights and weight scales replace the corresponding
        parameters on ``layer``; ``self.moe_quant_config`` is then refreshed
        and, when it is truthy, ``self.moe_mk`` is created.
        """
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded scales/weights and set up the runtime kernel."""
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 MoE quant config from the scales stored on ``layer``."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    @property
    def is_monolithic(self) -> bool:
        """True when the FlashInfer TRTLLM backend was selected."""
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FlashInfer TRTLLM FP8 per-tensor MoE path.

        Raises:
            NotImplementedError: If ``layer.enable_eplb`` is set.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
            )
        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        # The message previously said "FP4"; this is the FP8 method.
        assert layer.activation in SUPPORTED_ACTIVATIONS, (
            f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
            f"TRTLLM FP8 MoE, {layer.activation} found instead."
        )
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular-kernel path with pre-computed top-k routing."""
        assert not self.is_monolithic

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            assert layer.activation in (
                MoEActivation.SILU,
                MoEActivation.RELU2_NO_MUL,
            ), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.moe_mk is not None
        return self.moe_mk(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )

# Attach the FP8 linear, fused-MoE and KV-cache method classes to the config
# as class attributes.
# NOTE(review): presumably consumed by the base config's get_quant_method
# dispatch — confirm against ModelOptQuantConfigBase.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the NVFP4 checkpoint settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint is stored
                in NVFP4 format; a warning is logged when it is.
            kv_cache_quant_algo: KV-cache quantization algorithm name, if any.
            exclude_modules: Module patterns forwarded to the base config.
            group_size: Number of elements sharing one block scale.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set unconditionally: these used to be assigned only inside the
        # branch above, leaving a non-serialized config without the
        # attributes and causing AttributeError on later access.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the registered name of this quantization method."""
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value for this method."""
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Return ``"modelopt_fp4"`` when the checkpoint's algo is an FP4 one.

        Returns ``None`` when no algo is found or it does not contain
        ``"NVFP4"`` / ``"FP4"``.
        """
        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo is not None and ("NVFP4" in algo or "FP4" in algo):
            return "modelopt_fp4"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from already-extracted quantization fields.

        Args:
            quant_method: Quantization algorithm name; the checkpoint is
                treated as NVFP4-serialized when it contains ``"NVFP4"``.
            kv_cache_quant_method: KV-cache quantization algorithm, if any.
            exclude_modules: Module patterns to leave unquantized.
            original_config: The raw config dict; only its ``"quantization"``
                entry is inspected here.
            group_size: Block size; ``None`` falls back to 16.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If an NVFP4 checkpoint's ``"quantization"`` section
                lacks any of the required fields.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Parameters registered by ``create_weights``:

    input_scale: torch.float32, one entry per output partition,
    weight: NVFP4 packed two values per uint8 byte, shape [out, in / 2],
    weight_scale: FP8-E4M3, shape [out, in / group_size], aka per block scale,
    weight_scale_2: torch.float32, one entry per output partition.

    ``process_weights_after_loading`` reduces ``input_scale`` and
    ``weight_scale_2`` to scalars with ``max()``.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None
        self.backend = select_nvfp4_linear_backend()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 weight and its scales on ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Per-partition input feature count; must
                be a multiple of 16.
            output_partition_sizes: Sizes of the fused output partitions.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Unused; the checkpoint dtypes are fixed.
            **extra_weight_attrs: Only ``weight_loader`` is read here.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Per-block weight scales are stored as FP8-E4M3. The non-serialized
        # case has already raised above, so no params_dtype fallback is needed.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Global Scale
        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_global_scale)

        # Weight Global Scale
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_global_scale)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Rename/reduce the global scales and convert to kernel format.

        Replaces ``input_scale`` and ``weight_scale_2`` with scalar
        ``input_global_scale`` and ``weight_global_scale`` parameters, adds
        ``alpha`` and ``input_global_scale_inv``, then hands the layer to the
        backend-specific conversion.
        """
        # Rename ModelOpt checkpoint names to standardized names
        input_global_scale = layer.input_scale.max().to(torch.float32)
        layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
        del layer.input_scale
        weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
        del layer.weight_scale_2

        # Pre-compute alpha and inverse for runtime quantization
        layer.alpha = Parameter(
            layer.input_global_scale * layer.weight_global_scale, requires_grad=False
        )
        layer.input_global_scale_inv = Parameter(
            (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False
        )

        # Convert layer to NVFP4 linear kernel format
        convert_to_nvfp4_linear_kernel_format(self.backend, layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear op for the selected backend."""
        return apply_nvfp4_linear(
            backend=self.backend,
            layer=layer,
            x=x,
            bias=bias,
        )
1212

class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Args:
        quant_config: NVFP4 Quant Config
        moe_config: The fused-MoE configuration, forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

        # Decides whether input scales are sized by the global or the local
        # expert count in create_weights.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Legacy hook that this method does not support; always raises.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Legacy hook that this method does not support; always raises.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register packed FP4 expert weights and their scales.

        Registers ``w13_weight``/``w2_weight`` (uint8), the per-block scales
        ``w13_weight_scale``/``w2_weight_scale`` (FP8-E4M3), the per-tensor
        ``w13_weight_scale_2``/``w2_weight_scale_2`` and the input scales
        ``w13_input_scale``/``w2_input_scale`` (float32) on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Recorded on ``layer.params_dtype``.
            **extra_weight_attrs: ``weight_loader`` and ``global_num_experts``
                are read here.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        # Gated MoE fuses w1 and w3 along the output dimension of w13.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # one block scale per group_size elements of the input dimension
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one block scale per group_size elements of the input dimension
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): extra_weight_attrs is not read again after the
        # weight_loader / global_num_experts lookups above, so the two
        # update() calls in this method appear vestigial — confirm before
        # removing.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Input scales cover all experts when the backend supports global
        # scale factors, otherwise only the local ones.
        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.
        """

        # Use a single gscale for w13.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        # Only the first column is kept, even when the two columns differ.
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    @property
    def do_post_quant_allgather(self):
        """True when the FlashInfer TRTLLM backend was selected."""
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes ``hidden_states`` with ``flashinfer.fp4_quantize`` and
        returns the quantized tensor together with a one-element list holding
        its scale factors. ``router_logits`` is not used here.

        Raises:
            RuntimeError: If the backend is not FlashInfer TRTLLM.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        # Imported lazily so flashinfer is only required on this path.
        import flashinfer

        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 MoE quant config from the scales stored on ``layer``."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        """This method always reports EPLB support."""
        return True

    @property
    def is_monolithic(self) -> bool:
        """True for the FlashInfer TRTLLM backend when EPLB is disabled."""
        return (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not self.moe.moe_parallel_config.enable_eplb
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FlashInfer TRTLLM FP4 MoE path that takes router logits."""
        assert self.is_monolithic
        assert (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        )

        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=layer.top_k,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            custom_routing_function=layer.custom_routing_function,
            e_score_correction_bias=layer.e_score_correction_bias,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE with pre-computed top-k routing.

        The FlashInfer TRTLLM backend only reaches this method when EPLB is
        enabled and uses the routed TRTLLM entry point; every other backend
        goes through the modular kernel.
        """
        assert not self.is_monolithic

        # EPLB path
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
            )
        else:
            assert self.moe_mk is not None
            return self.moe_mk(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                shared_experts_input=shared_experts_input,
            )


# Attach the NVFP4 linear and fused-MoE method classes to the config as class
# attributes. The KV-cache slot reuses ModelOptFp8KVCacheMethod.
# NOTE(review): presumably consumed by the base config's get_quant_method
# dispatch — confirm against ModelOptQuantConfigBase.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptMxFp8Config(ModelOptQuantConfigBase):
    """Quantization config for ModelOpt checkpoints serialized as MXFP8.

    Only pre-quantized (serialized) checkpoints are accepted, and fused MoE
    layers are rejected in ``get_quant_method``.
    """

    def __init__(
        self,
        is_checkpoint_mxfp8_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the checkpoint flags.

        Raises:
            ValueError: If the checkpoint is not MXFP8 serialized.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized

        # Guard clause: there is no dynamic (on-the-fly) quantization path.
        if not is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization requires a serialized checkpoint. "
                "Dynamic quantization is not supported."
            )

        logger.warning(
            "Detected ModelOpt MXFP8 checkpoint. Please note that "
            "the format is experimental and could change in future."
        )

        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "modelopt_mxfp8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability (major * 10 + minor)."""
        # MXFP8 hardware acceleration requires Blackwell (SM100) or newer
        return 100

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method for ``layer``.

        Raises:
            NotImplementedError: If ``layer`` is a ``FusedMoE``.
        """
        # Everything except MoE is delegated to the base class.
        if not isinstance(layer, FusedMoE):
            return super().get_quant_method(layer, prefix)

        # MXFP8 does not yet support MoE models
        raise NotImplementedError(
            "MXFP8 quantization does not yet support MoE models. "
            "Please use FP8 or NVFP4 quantization for MoE models."
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Select ``modelopt_mxfp8`` when the HF config's algo mentions MXFP8.

        Returns ``None`` when no algorithm could be extracted or when it does
        not contain ``"MXFP8"``.
        """
        detected = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if detected is None or "MXFP8" not in detected:
            return None
        return "modelopt_mxfp8"

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptMxFp8Config":
        """Build the config from already-parsed quantization settings.

        Raises:
            ValueError: If ``original_config`` has a ``"quantization"``
                section that lacks ``kv_cache_quant_algo`` or
                ``exclude_modules``.
        """
        is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper()

        # For MXFP8, validate required fields in the config
        if is_checkpoint_mxfp8_serialized and "quantization" in original_config:
            section = original_config["quantization"]
            missing_fields = [
                name
                for name in ("kv_cache_quant_algo", "exclude_modules")
                if name not in section
            ]
            if missing_fields:
                raise ValueError(
                    f"MXFP8 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_mxfp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )


class ModelOptMxFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt MXFP8 quantization.

    Weights are registered as ``MXFP8_VALUE_DTYPE`` tensors of shape
    ``[N, K]`` together with one ``MXFP8_SCALE_DTYPE`` scale per
    ``MXFP8_BLOCK_SIZE`` elements along ``K``. After loading, the scales are
    either kept in their 2D layout (emulation backend) or swizzled
    (FlashInfer CUTLASS backend).
    """

    def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
        """Bind the config and create the MXFP8 GEMM op.

        Raises:
            ValueError: If the checkpoint is not MXFP8 serialized.
        """
        self.quant_config = quant_config

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 currently only supports serialized checkpoints. "
                "Dynamic quantization is not supported."
            )

        # The backend is fixed to FlashInfer CUTLASS here.
        self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
        self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
        logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the ``weight`` and ``weight_scale`` parameters on ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: K dimension of this partition.
            output_partition_sizes: Output widths of the logical shards; their
                sum is the N dimension of this partition.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Not used by this method; the parameter dtypes are
                fixed by the MXFP8 constants.
            **extra_weight_attrs: ``weight_loader`` is read from here.

        Raises:
            ValueError: If the checkpoint is not MXFP8 serialized, or if
                ``input_size_per_partition`` is not a multiple of
                ``MXFP8_BLOCK_SIZE``.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization was selected, but checkpoint is not "
                "MXFP8 serialized. Dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Each scale covers a whole block along K, so K must split evenly.
        if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
            raise ValueError(
                f"MXFP8 requires input dimension to be divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}"
            )

        # Weight tensor: FP8 E4M3 format
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=MXFP8_VALUE_DTYPE,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Weight scale tensor (E8M0 encoded as uint8), one scale per block of 32 along K
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // MXFP8_BLOCK_SIZE,
                dtype=MXFP8_SCALE_DTYPE,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
        """Not swizzled - MXFP8 GEMM emulation"""
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape
        scale_k = K // MXFP8_BLOCK_SIZE

        # Slice weight_scale to match weight dimensions (handles padding)
        weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()

        # Re-wrap as plain, non-trainable Parameters.
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
        """Swizzled - MXFP8 GEMM Flashinfer CUTLASS"""
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape

        # 2D weight scale
        weight_scale = layer.weight_scale.data

        # Swizzle the weight scales
        scale_k = K // MXFP8_BLOCK_SIZE
        weight_scale_2d = weight_scale[:N, :scale_k].contiguous()
        weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)

        # Re-wrap as plain, non-trainable Parameters.
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(
            weight_scale_swizzled.contiguous(), requires_grad=False
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate the loaded tensors and lay out the scales for the backend.

        Raises:
            ValueError: If the weight or the weight scale has an unexpected
                rank or dtype.
        """
        # Validate weight tensor
        if layer.weight.ndim != 2:
            raise ValueError(
                f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
                f"with shape {tuple(layer.weight.shape)}"
            )

        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), "
                f"got {layer.weight.dtype}. The checkpoint may not be properly "
                f"quantized with MXFP8."
            )

        # Validate weight scale tensor (should be 2D, not swizzled).
        # These checks guard checkpoint contents, so they raise explicitly
        # instead of using `assert`, which is removed under `python -O`.
        if layer.weight_scale.ndim != 2:
            raise ValueError(
                f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D"
            )
        if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE},"
                f" got {layer.weight_scale.dtype}"
            )

        if self.backend == Mxfp8LinearBackend.EMULATION:
            # Swizzled layout is not used
            self._process_weights_after_loading_scale_2d(layer)
            return

        # Internal invariant (not input validation): only two backends are
        # handled by this method.
        assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
        # Swizzled layout is required for Flashinfer CUTLASS
        self._process_weights_after_loading_scale_1d(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear op on ``x`` with the layer's weights.

        The output dtype is taken from ``x``.

        Raises:
            ValueError: If the layer's weight or weight scale dtype is not
                the expected MXFP8 dtype.
        """
        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
            )
        if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Weight scale dtype {layer.weight_scale.dtype} != "
                f"expected {MXFP8_SCALE_DTYPE}"
            )

        return self.mxfp8_linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )


# Register the method classes for ModelOptMxFp8Config.
# No FusedMoEMethodCls is registered: ModelOptMxFp8Config.get_quant_method
# raises NotImplementedError for FusedMoE layers.
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992


class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
    """Config class for ModelOpt MIXED_PRECISION.

    Supports checkpoints where different layers use different quantization
    algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
    The per-layer algorithm is specified in the ``quantized_layers`` dict
    inside ``config.json``'s ``quantization_config`` (preferred) or the
    legacy ``hf_quant_config.json``.

    Only the ``FP8`` and ``NVFP4`` algorithms are dispatched; a layer listed
    with any other algorithm is rejected in ``get_quant_method``.
    """

    def __init__(
        self,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        quantized_layers: dict[str, dict[str, Any]],
        fp8_config: ModelOptFp8Config,
        nvfp4_config: ModelOptNvFp4Config,
    ) -> None:
        """Store the per-layer mapping and the two delegate configs.

        Args:
            kv_cache_quant_method: KV-cache quantization setting; when falsy,
                attention layers get no quantize method.
            exclude_modules: Forwarded to the base class.
            quantized_layers: Maps a layer name to its quantization info
                dict (``quant_algo`` is read from it).
            fp8_config: Config handed to the FP8 linear / MoE methods.
            nvfp4_config: Config handed to the NVFP4 linear / MoE methods.
        """
        super().__init__(exclude_modules)
        self.kv_cache_quant_method = kv_cache_quant_method
        self.quantized_layers = quantized_layers
        self.fp8_config = fp8_config
        self.nvfp4_config = nvfp4_config

    def get_name(self) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "modelopt_mixed"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability (major * 10 + minor)."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Select ``modelopt_mixed`` when the HF config's algo is MIXED_PRECISION.

        Returns ``None`` for any other (or missing) algorithm.
        """
        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo == "MIXED_PRECISION":
            return "modelopt_mixed"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptMixedPrecisionConfig":
        """Build the config from already-parsed quantization settings.

        ``quantized_layers`` is read from ``original_config["quantization"]``
        when that section exists, otherwise from ``original_config`` itself.

        Raises:
            ValueError: If no non-empty ``quantized_layers`` mapping is found.
        """
        if "quantization" in original_config:
            quantized_layers = original_config["quantization"].get(
                "quantized_layers", {}
            )
        else:
            quantized_layers = original_config.get("quantized_layers", {})

        if not quantized_layers:
            raise ValueError(
                "MIXED_PRECISION quant_algo requires a non-empty "
                "'quantized_layers' mapping in the quantization config."
            )

        # Determine group_size from the first NVFP4 entry if not provided.
        if group_size is None:
            for layer_info in quantized_layers.values():
                if layer_info.get("quant_algo", "").upper() == "NVFP4":
                    group_size = layer_info.get("group_size", 16)
                    break
        if group_size is None:
            # No NVFP4 entry was found either; fall back to 16.
            group_size = 16

        # The delegate configs get empty exclusion lists: exclusions are
        # checked by this config in get_quant_method before dispatching.
        fp8_config = ModelOptFp8Config(
            quant_method="FP8",
            is_checkpoint_fp8_serialized=True,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=[],
        )
        nvfp4_config = ModelOptNvFp4Config(
            is_checkpoint_nvfp4_serialized=True,
            kv_cache_quant_algo=kv_cache_quant_method,
            exclude_modules=[],
            group_size=group_size,
        )

        return cls(
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            quantized_layers=quantized_layers,
            fp8_config=fp8_config,
            nvfp4_config=nvfp4_config,
        )

    def _resolve_quant_algo(self, prefix: str) -> str | None:
        """Look up the quant_algo for a vLLM-side layer prefix.

        Tries three strategies in order:
        1. Direct lookup in ``quantized_layers``.
        2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
        3. Prefix-based lookup for FusedMoE (any child key starts with
           ``prefix + "."``).

        Returns the upper-cased quant_algo string, or *None* if the prefix
        is not found.

        Raises:
            ValueError: If the shards of a packed layer disagree on their
                quant_algo.
        """
        # 1. Direct lookup
        if prefix in self.quantized_layers:
            return self.quantized_layers[prefix]["quant_algo"].upper()

        # 2. Packed / fused layer lookup
        proj_name = prefix.rsplit(".", 1)[-1]
        if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
            algos: set[str] = set()
            base = prefix.rsplit(".", 1)[0]
            for shard_name in self.packed_modules_mapping[proj_name]:
                shard_prefix = f"{base}.{shard_name}"
                if shard_prefix in self.quantized_layers:
                    algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
            if len(algos) == 1:
                return algos.pop()
            if len(algos) > 1:
                raise ValueError(
                    f"Mixed quant_algo within fused layer {prefix}: "
                    f"{algos}. All shards must use the same quantization."
                )
            # No shard was listed: fall through to the prefix-based lookup.

        # 3. Prefix-based lookup (for FusedMoE / parent modules)
        # The first child key found (in dict order) decides the algorithm.
        prefix_dot = prefix + "."
        for key, info in self.quantized_layers.items():
            if key.startswith(prefix_dot):
                return info["quant_algo"].upper()

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return quantize-method based on layer.

        Attention layers get the FP8 KV-cache method when KV-cache
        quantization is configured. Excluded linear layers and linear layers
        that are not listed in ``quantized_layers`` stay unquantized. Listed
        linear and fused-MoE layers are dispatched to the FP8 or NVFP4
        method according to their resolved quant_algo.

        Raises:
            ValueError: If a linear or fused-MoE layer resolves to a
                quant_algo other than ``FP8`` or ``NVFP4``.
        """
        # KV-cache quantization
        if isinstance(layer, Attention):
            if self.kv_cache_quant_method:
                return ModelOptFp8KVCacheMethod(self)
            return None

        # Excluded layers
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        quant_algo = self._resolve_quant_algo(prefix)

        if isinstance(layer, LinearBase):
            if quant_algo is None:
                # Layer not in quantized_layers — leave unquantized
                return UnquantizedLinearMethod()
            if quant_algo == "FP8":
                return ModelOptFp8LinearMethod(self.fp8_config)
            if quant_algo == "NVFP4":
                return ModelOptNvFp4LinearMethod(self.nvfp4_config)
            # The layer IS listed as quantized, but with an algorithm that has
            # no method here. Falling back to the unquantized method would
            # hide the mismatch, so fail loudly instead.
            raise ValueError(
                f"Unsupported quant_algo {quant_algo!r} for layer {prefix} in "
                f"MIXED_PRECISION checkpoint. Supported algorithms: FP8, NVFP4."
            )

        if isinstance(layer, FusedMoE):
            if quant_algo is None:
                # MoE layer not in quantized_layers
                return None
            if quant_algo == "FP8":
                return ModelOptFp8MoEMethod(
                    quant_config=self.fp8_config,
                    moe_config=layer.moe_config,
                )
            if quant_algo == "NVFP4":
                return ModelOptNvFp4FusedMoE(
                    quant_config=self.nvfp4_config,
                    moe_config=layer.moe_config,
                )
            # Same reasoning as for linear layers above.
            raise ValueError(
                f"Unsupported quant_algo {quant_algo!r} for layer {prefix} in "
                f"MIXED_PRECISION checkpoint. Supported algorithms: FP8, NVFP4."
            )

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Apply the HF-to-vLLM name mapping to this config.

        Besides the base-class handling, the ``quantized_layers`` mapping is
        passed through ``hf_to_vllm_mapper.apply_dict`` when it is non-empty.
        """
        super().apply_vllm_mapper(hf_to_vllm_mapper)
        if self.quantized_layers:
            self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)