"vscode:/vscode.git/clone" did not exist on "2fa2a50bf950797cb59d48908d205c655ec02654"
modelopt.py 80 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any
6
7
8
9

import torch
from torch.nn.parameter import Parameter

10
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
11
from vllm.logger import init_logger
12
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
13
from vllm.model_executor.layers.attention import Attention, MLAAttention
14
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
15
from vllm.model_executor.layers.fused_moe.config import (
16
    FusedMoEConfig,
17
    FusedMoEQuantConfig,
18
19
20
21
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
22
)
23
from vllm.model_executor.layers.fused_moe.layer import (
24
25
26
    FusedMoE,
    FusedMoeWeightScaleSupported,
)
27
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
28
    Fp8MoeBackend,
29
30
31
32
33
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
34
35
36
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
    select_mxfp8_moe_backend,
)
37
38
39
40
41
42
43
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
44
45
46
47
48
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
49
from vllm.model_executor.layers.quantization import QuantizationMethods
50
from vllm.model_executor.layers.quantization.base_config import (
51
52
53
    QuantizationConfig,
    QuantizeMethodBase,
)
54
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
55
56
57
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    swap_w13_to_w31,
)
58
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
60
61
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
62
)
63
64
65
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
66
67
68
69
70
71
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    MXFP8_BLOCK_SIZE,
    MXFP8_SCALE_DTYPE,
    MXFP8_VALUE_DTYPE,
    Mxfp8LinearBackend,
    Mxfp8LinearOp,
72
    mxfp8_e4m3_quantize,
73
    swizzle_mxfp8_scale,
74
)
75
76
77
78
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
    apply_nvfp4_linear,
    convert_to_nvfp4_linear_kernel_format,
    select_nvfp4_linear_backend,
79
)
80
from vllm.model_executor.layers.quantization.utils.quant_utils import (
81
82
    GroupShape,
    is_layer_skipped,
83
84
85
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
86
87
    kNvfp4Dynamic,
    kNvfp4Static,
88
)
89
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
90
    cutlass_block_fp8_supported,
91
92
    requantize_with_max_scale,
)
93
94
95
96
97
98
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
99
100
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
101

102
103
104
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

105
106
logger = init_logger(__name__)

107
108
109
110
111
112
113
114
115
# quant_algo values (already upper-cased) accepted by
# ModelOptQuantConfigBase.from_config; anything else is rejected there.
QUANT_ALGOS = [
    # FP8 (per-tensor weight + optional static activation scale).
    "FP8",
    # FP8 per-channel weight scale + per-token activation scale.
    "FP8_PER_CHANNEL_PER_TOKEN",
    # FP8 per-block weight-only (ModelOpt may emit this as lowercase).
    "FP8_PB_WO",
    # NVFP4
    "NVFP4",
    # MXFP8
    "MXFP8",
    # Mixed precision
    "MIXED_PRECISION",
]
# KV-cache quantization algorithm names. NOTE(review): presumably the set
# supported for the KV cache; it is not referenced in this part of the file.
KV_CACHE_QUANT_ALGOS = ["FP8"]
122
123


124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    All behavior is inherited from ``BaseKVCacheMethod``; this subclass only
    forwards the ModelOpt quantization config to the base constructor.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Thin pass-through: no ModelOpt-specific state is added here.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Shared base class for the ModelOpt quantization configs.

    It stores the list of excluded-module patterns, decides which quant
    method object a layer receives (``get_quant_method``), and parses the
    two supported config layouts (``from_config``). Concrete subclasses
    implement ``_from_config`` and select the ``*MethodCls`` classes.
    """

    # Classes instantiated by get_quant_method(). The values here are the
    # generic base classes; subclasses are expected to replace them.
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        """Store the exclusion patterns.

        Args:
            exclude_modules: Module names / wildcard patterns that must not
                be quantized.
        """
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and ModelOpt wildcard
        matching. The ModelOpt exclude_modules list is a list of wildcards.

        Args:
            prefix: Fully qualified name of the layer being checked.

        Returns:
            True when any of the three checks below (fused-aware exact match,
            legacy substring match, ``fnmatch`` wildcard match) hits.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0 where they are handled
        # naturally by the exclude_modules config. But need to keep them for
        # loading quantized checkpoints generated by older versions: check
        # substring matching for patterns not caught by the exact match.
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module == prefix:
                continue
            # NOTE(review): the "language_model." clause is implied by the
            # plain substring test (the stripped name is a suffix of
            # `prefix`); it is kept unchanged to preserve the original intent.
            if exclude_module in prefix or (
                prefix.startswith("language_model.")
                and exclude_module in prefix.removeprefix("language_model.")
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(fnmatch(prefix, pattern) for pattern in self.exclude_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantization method object for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Fully qualified name of ``layer``.

        Returns:
            * ``KVCacheMethodCls(self)`` for attention layers;
            * ``UnquantizedLinearMethod()`` for excluded linear layers and
              ``None`` for any other excluded layer;
            * ``UnquantizedLinearMethod()`` for every layer whose prefix
              contains ``vision_tower`` / ``vision_model`` (legacy rule,
              applied regardless of the layer type);
            * an instance of ``LinearMethodCls`` / ``FusedMoEMethodCls`` for
              quantized linear / fused-MoE layers;
            * ``None`` for everything else.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, (Attention, MLAAttention)):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0 where vision modules are
        # handled naturally by the exclude_modules config. It is kept for
        # loading quantized checkpoints generated by older versions.
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
        else:
            return None

        # Methods that report a "marlin" backend additionally need the input
        # dtype looked up for this layer.
        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` through the HF -> vLLM name mapper.

        Does nothing when there are no exclusion patterns.

        Args:
            hf_to_vllm_mapper: Mapper whose ``apply_list`` translates the
                (expanded) patterns.
        """
        if not self.exclude_modules:
            return

        # This is a workaround for the weights remapping issue:
        # https://github.com/vllm-project/vllm/issues/28072
        # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
        #        module_path*
        # It gets applied if the whole tree of modules rooted at module_path
        # is not quantized. Here we replace such pattern by 2 patterns that are
        # collectively equivalent to the original pattern:
        #        module_path
        #        module_path.*
        expanded: list[str] = []
        for pattern in self.exclude_modules:
            if len(pattern) >= 2 and pattern[-1] == "*" and pattern[-2] != ".":
                stem = pattern[:-1]
                expanded.extend((stem, stem + ".*"))
            else:
                expanded.append(pattern)

        self.exclude_modules = hf_to_vllm_mapper.apply_list(expanded)

    @staticmethod
    def _extract_modelopt_quant_algo(
        hf_quant_cfg: dict[str, Any] | None,
    ) -> str | None:
        """Extract upper-cased quant_algo from a modelopt config.

        Returns the quant_algo string (upper-cased), or None if the config
        is not a modelopt config. A ``quant_method`` entry that is not a
        string (for example JSON ``null``) is treated as "not modelopt"
        instead of raising ``AttributeError``.
        """
        if hf_quant_cfg is None:
            return None

        method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(method, str) or method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Legacy layout: the algo lives in a nested "quantization" dict.
            nested = hf_quant_cfg["quantization"]
            if isinstance(nested, dict):
                return str(nested.get("quant_algo", "")).upper()
            return None

        # Flat layout: the algo is a top-level key.
        return str(hf_quant_cfg.get("quant_algo", "")).upper()

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the quantization config file names to look for."""
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from the values validated by ``from_config``.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a ModelOpt quantization config.

        Two layouts are accepted: the legacy ``hf_quant_config.json`` layout
        with a nested ``"quantization"`` dict, and the flat layout stored in
        ``config.json`` (``"quant_algo"``, ``"ignore"``, ``"kv_cache_scheme"``).

        Args:
            config: The raw quantization config dictionary.

        Returns:
            The object produced by ``cls._from_config``.

        Raises:
            ValueError: If ``quant_algo`` is missing or unsupported, or if
                ``kv_cache_quant_algo`` / ``exclude_modules`` / ``group_size``
                have the wrong type.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")

            # Types of the two values below are validated further down.
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format (config.json quantization_config):
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")

            # "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
            kv_cache_scheme = config.get("kv_cache_scheme")
            if isinstance(kv_cache_scheme, dict) and (
                kv_cache_scheme.get("type") == "float"
                and kv_cache_scheme.get("num_bits") == 8
            ):
                kv_cache_quant_method = "FP8"
            else:
                kv_cache_quant_method = None

            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        # None means "no KV cache quantization" and is passed through as-is.
        if kv_cache_quant_method is not None:
            if not isinstance(kv_cache_quant_method, str):
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_method)}"
                )
            kv_cache_quant_method = kv_cache_quant_method.upper()

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None or isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            # Accept values that convert cleanly (e.g. numeric strings).
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the FP8 settings and pick the linear method class.

        Args:
            quant_method: Upper-cased quant_algo; one of ``FP8``,
                ``FP8_PER_CHANNEL_PER_TOKEN`` or ``FP8_PB_WO``.
            is_checkpoint_fp8_serialized: Whether the checkpoint weights are
                already stored as FP8.
            kv_cache_quant_method: KV-cache quant algo, or ``None``.
            exclude_modules: Module patterns excluded from quantization.

        Raises:
            ValueError: If ``quant_method`` is not one of the three values
                above.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo. The classes
        # are defined later in this module, so they are resolved here at
        # instantiation time rather than at class-definition time.
        if self.quant_method == "FP8":
            self.LinearMethodCls = ModelOptFp8LinearMethod
        elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
            self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
        elif self.quant_method == "FP8_PB_WO":
            self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
        else:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (89)."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Claim the checkpoint when it is a ModelOpt config with algo ``FP8``.

        Args:
            hf_quant_cfg: The checkpoint's quantization config (may be None).
            user_quant: The user-requested quantization; not consulted here.

        Returns:
            ``"modelopt"`` when the extracted quant_algo is exactly ``FP8``,
            otherwise ``None``.
        """
        # A None result (not a modelopt config) simply fails the comparison.
        if cls._extract_modelopt_quant_algo(hf_quant_cfg) == "FP8":
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config from the values validated by ``from_config``.

        ``original_config`` and any extra keyword arguments (such as
        ``group_size``) are accepted but not used.
        """
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
438

439
440
441
442

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Both activations and weights use the kFp8StaticTensorSym quant key.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8StaticTensorSym,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, the two scale params.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers
                fused into this layer.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype used when the checkpoint is not
                FP8-serialized.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # One scale slot per logical sub-layer. Slots start at the most
            # negative float32 value; presumably a "not loaded yet" sentinel
            # -- TODO confirm.
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse the per-sub-layer scales to one and transpose the weight."""
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        # Sub-layers with differing scales go through
        # requantize_with_max_scale, which supplies the scale/weight pair
        # stored below.
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel on ``x`` using the parameters on ``layer``."""
        return self.fp8_linear.apply_weights(layer, x, bias)
527
528


529
530
531
532
533
534
535
536
537
538
539
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.

    Expected checkpoint structure (per Linear):
    - weight: fp8-e4m3fn, shape [out, in]
    - weight_scale: fp32, shape [out] (per-output-channel)
    - no input_scale (activations are dynamically quantized per-token)
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations use the kFp8DynamicTokenSym key, weights the
        # kFp8StaticTokenSym key.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8StaticTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and the per-channel ``weight_scale``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers
                fused into this layer.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Not read; the weight is always float8_e4m3fn.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One fp32 scale per output channel, initialised to the most negative
        # float32 value (presumably a "not loaded yet" sentinel -- TODO
        # confirm).
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty(output_size_per_partition, dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transpose the weight and rewrap both tensors as plain Parameters."""
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel on ``x`` using the parameters on ``layer``."""
        return self.fp8_linear.apply_weights(layer, x, bias)
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    # (block_n, block_k): block extents along the output / input dimension.
    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        # List copy of the block size; create_weights() exposes it on the layer.
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and the block-wise ``weight_scale``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers
                fused into this layer.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Not read; the weight is always float8_e4m3fn.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized, or if either
                partition size is not a multiple of the block size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Rewrap the weight and flatten the scale to ``[out_blk, in_blk]``.

        Raises:
            ValueError: If ``weight_scale`` is neither 4-D nor 2-D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

        # NOTE(review): nothing in this class sets `fp8_linear`, so this hook
        # only fires if a subclass or external code attaches one -- confirm
        # whether it is still needed.
        if hasattr(self, "fp8_linear"):
            self.fp8_linear.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-wise FP8 GEMM; no static input scale is passed."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


726
727
728
729
730
731
732
733
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        moe_config: Fused-MoE configuration forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        assert self.quant_config.is_checkpoint_fp8_serialized

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Always raise; this method builds its kernel in ``_setup_kernel``.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Always raise; this method builds its kernel in ``_setup_kernel``.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register expert weights and per-tensor scales.

        Registers ``w13_weight``, ``w2_weight``, ``w13_weight_scale``,
        ``w2_weight_scale``, ``w13_input_scale`` and ``w2_input_scale`` on
        ``layer``. All scales are float32 and initialized to 1.0.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Input width of w13 and output width of w2.
            intermediate_size_per_partition: Per-shard width of w13 output
                and of w2 input.
            params_dtype: Stored on ``layer.orig_dtype``; also the weight
                dtype when the checkpoint is not FP8-serialized.
            **extra_weight_attrs: Only ``weight_loader`` is read here.
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stacks w1 and w3 in w13; non-gated MoE has a single shard.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, w13_num_shards),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend format and build the MoE kernel.

        The converted weights and weight scales replace the corresponding
        parameters on ``layer``; ``self.moe_quant_config`` and
        ``self.moe_kernel`` are then (re)created from the updated layer.
        """
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel. The quant config reads the scales back from
        # `layer`, so it must be built after the replacements above.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.experts_cls is not None
        self.moe_kernel = make_fp8_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            fp8_backend=self.fp8_backend,
            experts_cls=self.experts_cls,
            routing_tables=layer._maybe_init_expert_routing_tables(),
            shared_experts=layer.shared_experts,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reduce loaded scales to per-tensor form and set up the kernel.

        Input scales and the w1/w3 weight scales are post-processed by the
        ``process_fp8_*_tensor_strategy_moe`` helpers, after which
        ``_setup_kernel`` converts the weights and builds the kernel.
        """
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
        """Build the FP8 MoE quant config from the scales stored on ``layer``."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Run the monolithic kernel path, passing raw router logits.

        Requires ``self.is_monolithic`` and an already-built ``moe_kernel``.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the modular kernel path with precomputed top-k routing.

        Requires ``self.is_monolithic`` to be false and an already-built
        ``moe_kernel``.
        """
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )

979

980
981
982
983
984
985
# Attach the linear, fused-MoE and KV-cache method classes as class attributes
# of ModelOptFp8Config.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the NVFP4 settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint already
                holds NVFP4 weights; a warning is logged when true.
            kv_cache_quant_algo: KV-cache quantization algorithm, if any.
            exclude_modules: Forwarded to the base config.
            group_size: Number of input elements per block scale.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set unconditionally so that a non-serialized config still exposes
        # these attributes instead of failing later with AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name, ``"modelopt_fp4"``."""
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (75)."""
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Return ``"modelopt_fp4"`` when the checkpoint's algo mentions FP4.

        Returns ``None`` when no algo is found or it does not contain
        ``"NVFP4"``/``"FP4"``. ``user_quant`` is not consulted.
        """
        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo is not None and ("NVFP4" in algo or "FP4" in algo):
            return "modelopt_fp4"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from already-parsed checkpoint fields.

        Args:
            quant_method: Quantization algorithm string; the checkpoint is
                treated as NVFP4-serialized when it contains ``"NVFP4"``.
            kv_cache_quant_method: Passed through as ``kv_cache_quant_algo``.
            exclude_modules: Passed through to the constructor.
            original_config: Raw config dict; its ``"quantization"`` entry,
                when present, is checked for required fields.
            group_size: Block size; ``None`` falls back to 16.
            **kwargs: Ignored.

        Raises:
            ValueError: If the ``"quantization"`` section of an NVFP4
                checkpoint lacks any required field.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )
1061
1062
1063
1064
1065


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Parameters registered by ``create_weights``:

    - ``weight``: packed FP4 stored as ``torch.uint8``,
      shape ``[out, in // 2]``.
    - ``weight_scale``: per-block scale in FP8-E4M3,
      shape ``[out, in // group_size]``.
    - ``input_scale``: ``torch.float32``, one entry per output partition.
    - ``weight_scale_2``: ``torch.float32``, one entry per output partition.

    After loading, ``input_scale`` and ``weight_scale_2`` are each reduced to
    a single value with ``max()``.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None
        self.backend = select_nvfp4_linear_backend()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 weight and its scales on ``layer``.

        Args:
            layer: Module that receives the parameters and size attributes.
            input_size_per_partition: Input features held by this shard.
            output_partition_sizes: Output widths of the fused sub-layers.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Unused; block scales are always FP8-E4M3 here.
            **extra_weight_attrs: Only ``weight_loader`` is read here.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Block scales are stored as FP8-E4M3. The non-serialized case was
        # rejected above, so no fallback to params_dtype is needed.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Global Scale
        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_global_scale)

        # Weight Global Scale
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_global_scale)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Derive global scales and convert the layer to the kernel format.

        Replaces ``input_scale``/``weight_scale_2`` with
        ``input_global_scale``/``weight_global_scale`` (their maxima), then
        adds ``alpha`` (their product) and ``input_global_scale_inv``.
        """
        # Rename ModelOpt checkpoint names to standardized names
        input_global_scale = layer.input_scale.max().to(torch.float32)
        layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
        del layer.input_scale
        weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
        del layer.weight_scale_2

        # Pre-compute alpha and inverse for runtime quantization
        layer.alpha = Parameter(
            layer.input_global_scale * layer.weight_global_scale, requires_grad=False
        )
        layer.input_global_scale_inv = Parameter(
            (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False
        )

        # Convert layer to NVFP4 linear kernel format
        convert_to_nvfp4_linear_kernel_format(self.backend, layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear op for ``x`` with the selected backend."""
        return apply_nvfp4_linear(
            backend=self.backend,
            layer=layer,
            x=x,
            bias=bias,
        )
1185

1186
1187
1188
1189

class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Args:
        quant_config: NVFP4 Quant Config
        moe_config: Fused-MoE configuration forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

        # Decides whether input scales are sized by the global expert count
        # (see create_weights).
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Always raise; the kernel is built in process_weights_after_loading.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register packed FP4 expert weights and their scales.

        Registers, for both ``w13`` and ``w2``: the packed weight (uint8),
        the block scale (FP8-E4M3), the per-tensor ``*_weight_scale_2``
        (float32) and the ``*_input_scale`` (float32).

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Input width of w13 and output width of w2.
            intermediate_size_per_partition: Per-shard width of w13 output
                and of w2 input.
            params_dtype: Stored on ``layer.params_dtype``.
            **extra_weight_attrs: ``weight_loader`` and
                ``global_num_experts`` are read here.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        # Gated MoE stacks w1 and w3 in w13; non-gated MoE has a single shard.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # NOTE(review): the tensors below are laid out [expert, out, in] but
        # declare input_dim=1/output_dim=2, whereas ModelOptFp8MoEMethod in
        # this file declares input_dim=2/output_dim=1 for the same layout.
        # Confirm which is intended.

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # one block scale per `group_size` input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one block scale per `group_size` input elements
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): nothing in this method reads "quant_method" back from
        # extra_weight_attrs after these updates; confirm they are still needed.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Input scales span all experts when the backend supports a global
        # scale factor, otherwise only the experts held by this layer.
        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.
        """

        # Use a single gscale for w13. Only the w1 column is kept; a mismatch
        # with the w3 column is reported once but not corrected.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel. The quant config reads the scales back from
        # `layer`, so it must be built after the replacements above.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.experts_cls is not None
        self.moe_kernel = make_nvfp4_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            experts_cls=self.experts_cls,
            shared_experts=layer.shared_experts,
            routing_tables=layer._maybe_init_expert_routing_tables(),
        )
        self.moe_kernel.fused_experts.process_weights_after_loading(layer)

    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
        """Build the NVFP4 MoE quant config from the scales stored on ``layer``."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        """Always ``True`` for this method."""
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Run the monolithic kernel path, passing raw router logits.

        Requires ``self.is_monolithic`` and an already-built ``moe_kernel``.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the modular kernel path with precomputed top-k routing.

        Requires ``self.is_monolithic`` to be false and an already-built
        ``moe_kernel``.
        """
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
1462
1463
1464
1465
1466


# Attach the linear, fused-MoE and KV-cache method classes as class attributes
# of ModelOptNvFp4Config. The KV-cache method is the FP8 one.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508


class ModelOptMxFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt MXFP8."""

    def __init__(
        self,
        is_checkpoint_mxfp8_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the MXFP8 settings.

        Args:
            is_checkpoint_mxfp8_serialized: Must be true.
            kv_cache_quant_algo: KV-cache quantization algorithm, if any.
            exclude_modules: Forwarded to the base config.

        Raises:
            ValueError: If ``is_checkpoint_mxfp8_serialized`` is false.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized

        if not is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization requires a serialized checkpoint. "
                "Dynamic quantization is not supported."
            )

        logger.warning(
            "Detected ModelOpt MXFP8 checkpoint. Please note that "
            "the format is experimental and could change in future."
        )

        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name, ``"modelopt_mxfp8"``."""
        return "modelopt_mxfp8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (100)."""
        # MXFP8 hardware acceleration requires Blackwell (SM100) or newer
        return 100

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Return ``"modelopt_mxfp8"`` when the checkpoint's algo has MXFP8.

        Returns ``None`` when no algo is found or it does not contain
        ``"MXFP8"``. ``user_quant`` is not consulted.
        """
        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo is not None and "MXFP8" in algo:
            return "modelopt_mxfp8"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptMxFp8Config":
        """Build the config from already-parsed checkpoint fields.

        Args:
            quant_method: Quantization algorithm string; matched
                case-insensitively against ``"MXFP8"``.
            kv_cache_quant_method: Passed through as ``kv_cache_quant_algo``.
            exclude_modules: Passed through to the constructor.
            original_config: Raw config dict; its ``"quantization"`` entry,
                when present, is checked for required fields.
            **kwargs: Ignored.

        Raises:
            ValueError: If the ``"quantization"`` section of an MXFP8
                checkpoint lacks any required field, or (from the
                constructor) if ``quant_method`` does not name MXFP8.
        """
        is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper()

        # For MXFP8, validate required fields in the config
        if is_checkpoint_mxfp8_serialized and "quantization" in original_config:
            quant_config = original_config["quantization"]
            required_fields = ["kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"MXFP8 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_mxfp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )


class ModelOptMxFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt MXFP8 quantization."""

    def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
        """Store the config and set up the MXFP8 GEMM operator.

        Args:
            quant_config: ModelOpt MXFP8 configuration for the model.

        Raises:
            ValueError: If the checkpoint is not MXFP8-serialized; dynamic
                quantization is not supported.
        """
        self.quant_config = quant_config

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 currently only supports serialized checkpoints. "
                "Dynamic quantization is not supported."
            )

        # The FlashInfer CUTLASS backend is hard-wired here; there is no
        # selection logic in this constructor.
        self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
        self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
        logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the MXFP8 weight and block-scale parameters.

        Registers ``weight`` of shape ``[N, K]`` and ``weight_scale`` of shape
        ``[N, K // MXFP8_BLOCK_SIZE]`` on ``layer``, where ``N`` is the sum of
        ``output_partition_sizes`` and ``K`` is ``input_size_per_partition``.
        The partition sizes are also recorded on ``layer``.

        Raises:
            ValueError: If the checkpoint is not MXFP8-serialized, or ``K`` is
                not a multiple of ``MXFP8_BLOCK_SIZE``.
        """
        # The full (unpartitioned) sizes are not needed here.
        del input_size, output_size

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization was selected, but checkpoint is not "
                "MXFP8 serialized. Dynamic quantization is not supported."
            )

        n_out = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = n_out

        if input_size_per_partition % MXFP8_BLOCK_SIZE:
            raise ValueError(
                f"MXFP8 requires input dimension to be divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}"
            )

        # Quantized weight values.
        weight_storage = torch.empty(
            n_out, input_size_per_partition, dtype=MXFP8_VALUE_DTYPE
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # One scale per MXFP8_BLOCK_SIZE elements along the input (K) dimension.
        scale_storage = torch.empty(
            n_out,
            input_size_per_partition // MXFP8_BLOCK_SIZE,
            dtype=MXFP8_SCALE_DTYPE,
        )
        layer.register_parameter(
            "weight_scale",
            ModelWeightParameter(
                data=scale_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

    def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
        """Finalize weights for the emulation path, keeping scales 2D (no swizzle).

        The scale tensor is cropped to ``[N, K // MXFP8_BLOCK_SIZE]`` to match
        the ``[N, K]`` weight, and both are re-wrapped as contiguous,
        non-trainable parameters.
        """
        w = layer.weight.data  # [N, K]
        rows, cols = w.shape

        # Crop away any padding so the scales line up with the weight.
        cropped_scale = layer.weight_scale.data[:rows, : cols // MXFP8_BLOCK_SIZE]
        cropped_scale = cropped_scale.contiguous()

        layer.weight = Parameter(w.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(cropped_scale, requires_grad=False)

    def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
        """Finalize weights for FlashInfer CUTLASS, which needs swizzled scales.

        The 2D scale tensor is cropped to ``[N, K // MXFP8_BLOCK_SIZE]``,
        passed through ``swizzle_mxfp8_scale`` and stored, together with the
        contiguous weight, as non-trainable parameters.
        """
        w = layer.weight.data  # [N, K]
        rows, cols = w.shape
        blocks_per_row = cols // MXFP8_BLOCK_SIZE

        # Crop the loaded 2D scales to the weight's extent, then swizzle.
        scale_2d = layer.weight_scale.data[:rows, :blocks_per_row].contiguous()
        swizzled = swizzle_mxfp8_scale(scale_2d, M=rows, K=cols)

        layer.weight = Parameter(w.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(swizzled.contiguous(), requires_grad=False)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate the loaded MXFP8 tensors and convert them to backend layout.

        The emulation backend keeps the scales 2D; the FlashInfer CUTLASS
        backend gets swizzled scales.

        Raises:
            ValueError: If the weight is not a 2D ``MXFP8_VALUE_DTYPE`` tensor,
                or the weight scale is not a 2D ``MXFP8_SCALE_DTYPE`` tensor.
        """
        # Validate weight tensor
        if layer.weight.ndim != 2:
            raise ValueError(
                f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
                f"with shape {tuple(layer.weight.shape)}"
            )

        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), "
                f"got {layer.weight.dtype}. The checkpoint may not be properly "
                f"quantized with MXFP8."
            )

        # Validate weight scale tensor (should be 2D, not swizzled). These are
        # checks on checkpoint data, so raise instead of asserting: asserts
        # vanish under ``python -O``.
        if layer.weight_scale.ndim != 2:
            raise ValueError(
                f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D"
            )
        if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE},"
                f" got {layer.weight_scale.dtype}"
            )

        if self.backend == Mxfp8LinearBackend.EMULATION:
            # Swizzled layout is not used
            self._process_weights_after_loading_scale_2d(layer)
            return

        # Internal invariant: only two backends are handled here.
        assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
        # Swizzled layout is required for Flashinfer CUTLASS
        self._process_weights_after_loading_scale_1d(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear op on ``x`` using the layer's weight and scale.

        The output dtype matches ``x.dtype``.

        Raises:
            ValueError: If the weight or weight-scale dtype is not the expected
                MXFP8 dtype.
        """
        weight = layer.weight
        if weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"Weight dtype {weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
            )

        weight_scale = layer.weight_scale
        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Weight scale dtype {weight_scale.dtype} != "
                f"expected {MXFP8_SCALE_DTYPE}"
            )

        result = self.mxfp8_linear_op.apply(
            input=x,
            weight=weight,
            weight_scale=weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )
        return result


class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
    """FlashInfer TRTLLM MXFP8 block-scale MoE for ModelOpt checkpoints."""

    def __init__(
        self,
        quant_config: ModelOptMxFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        """Bind the quantization config and select the MXFP8 MoE backend.

        Args:
            quant_config: ModelOpt MXFP8 config; asserted to describe an
                MXFP8-serialized checkpoint.
            moe_config: Fused-MoE configuration forwarded to the base class.
        """
        super().__init__(moe_config)
        self.quant_config = quant_config
        assert self.quant_config.is_checkpoint_mxfp8_serialized

        # Only the first element of the selection result (the backend) is kept.
        self.mxfp8_backend, _ = select_mxfp8_moe_backend(self.moe)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register per-expert MXFP8 weights and block scales.

        With ``E = num_experts``, ``H = hidden_size`` and
        ``I = intermediate_size_per_partition``, registers on ``layer``:

        * ``w13_weight``       ``[E, 2I or I, H]`` (2I when ``is_act_and_mul``)
        * ``w2_weight``        ``[E, H, I]``
        * ``w13_weight_scale`` ``[E, 2I or I, H // MXFP8_BLOCK_SIZE]``
        * ``w2_weight_scale``  ``[E, H, I // MXFP8_BLOCK_SIZE]``

        Raises:
            ValueError: If ``H`` or ``I`` is not a multiple of
                ``MXFP8_BLOCK_SIZE``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.orig_dtype = params_dtype

        if hidden_size % MXFP8_BLOCK_SIZE:
            raise ValueError(
                f"MXFP8 MoE requires hidden_size divisible by {MXFP8_BLOCK_SIZE}, "
                f"got {hidden_size}."
            )
        if intermediate_size_per_partition % MXFP8_BLOCK_SIZE:
            raise ValueError(
                "MXFP8 MoE requires intermediate_size_per_partition divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {intermediate_size_per_partition}."
            )

        layer.num_experts = num_experts
        loader = extra_weight_attrs.get("weight_loader")
        shard_count = 2 if self.moe.is_act_and_mul else 1
        gemm1_rows = shard_count * intermediate_size_per_partition

        def _expert_param(rows: int, cols: int, dtype: torch.dtype):
            # Every expert tensor is [E, rows, cols] with dim 1 as the output
            # dimension and dim 2 as the input dimension.
            return ModelWeightParameter(
                data=torch.empty(num_experts, rows, cols, dtype=dtype),
                input_dim=2,
                output_dim=1,
                weight_loader=loader,
            )

        # GEMM 1 weights, then GEMM 2 weights.
        layer.register_parameter(
            "w13_weight", _expert_param(gemm1_rows, hidden_size, MXFP8_VALUE_DTYPE)
        )
        layer.register_parameter(
            "w2_weight",
            _expert_param(
                hidden_size, intermediate_size_per_partition, MXFP8_VALUE_DTYPE
            ),
        )

        # Block scales: one per MXFP8_BLOCK_SIZE elements along the input dim.
        layer.register_parameter(
            "w13_weight_scale",
            _expert_param(
                gemm1_rows, hidden_size // MXFP8_BLOCK_SIZE, MXFP8_SCALE_DTYPE
            ),
        )
        layer.register_parameter(
            "w2_weight_scale",
            _expert_param(
                hidden_size,
                intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
                MXFP8_SCALE_DTYPE,
            ),
        )

        # Tag both scale tensors so the generic MoE weight loader treats them
        # as block scales.
        for scale_param in (layer.w13_weight_scale, layer.w2_weight_scale):
            set_weight_attrs(
                scale_param,
                {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value},
            )

    @staticmethod
    def _check_weight_dtypes(layer: torch.nn.Module) -> None:
        """Ensure the four MoE tensors on ``layer`` carry the MXFP8 dtypes.

        Raises:
            ValueError: On the first tensor whose dtype differs from the
                expected one.
        """
        checks = (
            ("w13_weight", MXFP8_VALUE_DTYPE),
            ("w2_weight", MXFP8_VALUE_DTYPE),
            ("w13_weight_scale", MXFP8_SCALE_DTYPE),
            ("w2_weight_scale", MXFP8_SCALE_DTYPE),
        )
        for attr, wanted in checks:
            found = getattr(layer, attr).dtype
            if found != wanted:
                raise ValueError(f"Expected {attr} dtype {wanted}, got {found}.")

    def _shuffle_weights_for_trtllm(self, layer: torch.nn.Module) -> None:
        """Shuffle weights and scales into FlashInfer TRTLLM MXFP8 layout.

        For every expert, the w13/w2 weights and their block scales are passed
        through FlashInfer's shuffle helpers; the per-expert results are then
        stacked and written back onto ``layer`` via ``replace_parameter``
        under the original parameter names. For gated MoE
        (``self.moe.is_act_and_mul``) w13 is first converted from W13 to W31
        ordering and its rows are reordered for the gated-activation GEMM.
        """
        # Imported lazily, inside the method, rather than at module import time.
        from flashinfer import (
            reorder_rows_for_gated_act_gemm,
            shuffle_matrix_a,
            shuffle_matrix_sf_a,
        )

        # Tile size handed to every shuffle helper below.
        epilogue_tile_m = 128
        num_experts = layer.w13_weight.shape[0]
        is_gated = self.moe.is_act_and_mul
        # Gated MoE packs two projections into w13, doubling its row count.
        intermediate_size_factor = 2 if is_gated else 1

        w13_weight = layer.w13_weight.data
        w13_scale = layer.w13_weight_scale.data
        if is_gated:
            # FI TRTLLM gated kernels use W31 ordering. Model checkpoints store
            # gated projection as W13, so convert once before shuffling.
            w13_weight = swap_w13_to_w31(w13_weight)
            w13_scale = swap_w13_to_w31(w13_scale)

        # Per-expert results, stacked back into [E, ...] tensors at the end.
        w13_weight_shuffled = []
        w2_weight_shuffled = []
        w13_scale_shuffled = []
        w2_scale_shuffled = []
        for i in range(num_experts):
            # View expert i's w13 weight/scale as 2D [(2I or I), -1] matrices.
            w13_i = w13_weight[i].reshape(
                intermediate_size_factor * layer.intermediate_size_per_partition, -1
            )
            w13_sf_i = w13_scale[i].reshape(
                intermediate_size_factor * layer.intermediate_size_per_partition, -1
            )
            if is_gated:
                # Reorder rows for gated activation layout expected by TRTLLM.
                w13_i = reorder_rows_for_gated_act_gemm(w13_i.clone())
                w13_sf_i = reorder_rows_for_gated_act_gemm(w13_sf_i.clone())

            # The shuffle helpers are fed uint8 views; results are viewed back
            # as the MXFP8 dtypes when appended.
            w13_shuffled_i = shuffle_matrix_a(w13_i.view(torch.uint8), epilogue_tile_m)
            w2_shuffled_i = shuffle_matrix_a(
                layer.w2_weight.data[i].view(torch.uint8), epilogue_tile_m
            )
            w13_weight_shuffled.append(
                w13_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE)
            )
            w2_weight_shuffled.append(
                w2_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE)
            )
            # Scale factors go through the dedicated scale-factor shuffle.
            w13_sf_shuffled_i = shuffle_matrix_sf_a(
                w13_sf_i.view(torch.uint8).reshape(
                    intermediate_size_factor * layer.intermediate_size_per_partition,
                    -1,
                ),
                epilogue_tile_m,
            )
            w2_sf_shuffled_i = shuffle_matrix_sf_a(
                layer.w2_weight_scale.data[i]
                .view(torch.uint8)
                .reshape(layer.hidden_size, -1),
                epilogue_tile_m,
            )
            w13_scale_shuffled.append(
                w13_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE)
            )
            w2_scale_shuffled.append(
                w2_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE)
            )

        # Swap the layer's parameters for the stacked, shuffled tensors.
        replace_parameter(
            layer, "w13_weight", torch.stack(w13_weight_shuffled).contiguous()
        )
        replace_parameter(
            layer, "w2_weight", torch.stack(w2_weight_shuffled).contiguous()
        )
        replace_parameter(
            layer,
            "w13_weight_scale",
            torch.stack(w13_scale_shuffled).contiguous(),
        )
        replace_parameter(
            layer,
            "w2_weight_scale",
            torch.stack(w2_scale_shuffled).contiguous(),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate dtypes and shuffle the MoE weights, at most once per layer.

        A marker attribute is set on ``layer`` after the first successful run;
        later calls see the marker and return immediately.
        """
        done_flag = "_already_called_process_weights_after_loading"
        if getattr(layer, done_flag, False):
            return

        self._check_weight_dtypes(layer)
        self._shuffle_weights_for_trtllm(layer)
        setattr(layer, done_flag, True)

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Always raise: this hook is not used by this method class.

        Raises:
            ValueError: Unconditionally.
        """
        owner = self.__class__.__name__
        raise ValueError(
            f"{owner} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Always raise: this hook is not used by this method class.

        Raises:
            ValueError: Unconditionally.
        """
        owner = self.__class__.__name__
        raise ValueError(
            f"{owner} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return ``None``; no modular-kernel quant config is produced.

        The TRTLLM MXFP8 path is monolithic and does not use a modular kernel
        config.
        """
        return None

    @property
    def is_monolithic(self) -> bool:
        """Whether the selected MoE backend is FlashInfer TRTLLM."""
        return self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FlashInfer TRTLLM MXFP8 block-scale MoE kernel.

        Args:
            layer: Fused-MoE layer whose ``w13``/``w2`` weights and scales are
                handed to the kernel with ``use_shuffled_weight=True``.
            x: Activations; quantized here with ``mxfp8_e4m3_quantize`` using
                the non-swizzled scale layout.
            router_logits: Router logits. Must already be float32 for
                DeepSeekV3 routing; cast to bfloat16 for any other routing
                method.

        Returns:
            The result of ``flashinfer_trtllm_fp8_block_scale_moe``.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer`` or its
                activation is not SiLU.
        """
        from flashinfer.fused_moe.core import (
            ActivationType,
            Fp8QuantizationType,
        )

        assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend."
            )

        supported_activations = [MoEActivation.SILU]
        if layer.activation not in supported_activations:
            raise NotImplementedError(
                "FlashInfer TRTLLM MXFP8 MoE supports only "
                f"{supported_activations}, got {layer.activation}."
            )

        # Map vLLM MoEActivation to FlashInfer ActivationType.
        activation_map = {
            MoEActivation.SILU: ActivationType.Swiglu,
            MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
        }
        fi_activation_type: ActivationType = activation_map[layer.activation]
        # Defensive guard, checked before any tensor work so that a future
        # widening of ``supported_activations`` fails fast instead of after
        # quantizing the activations.
        if fi_activation_type != ActivationType.Swiglu:
            raise NotImplementedError(
                "FlashInfer TRTLLM MXFP8 MoE supports only Swiglu activation, "
                f"got {fi_activation_type}."
            )

        # DeepSeekV3 routing requires float32 logits; others expect bfloat16.
        if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
            assert router_logits.dtype == torch.float32, (
                "DeepSeekV3 routing requires float32 router_logits, "
                f"got {router_logits.dtype}."
            )
        else:
            router_logits = router_logits.to(torch.bfloat16)

        # Treat 0 as "unset" for compatibility with ungrouped routing configs.
        n_group = layer.num_expert_group or None
        topk_group = layer.topk_group or None

        hidden_states_mxfp8, hidden_states_scale = mxfp8_e4m3_quantize(
            x,
            is_sf_swizzled_layout=False,
        )

        return flashinfer_trtllm_fp8_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            hidden_states=hidden_states_mxfp8,
            hidden_states_scale=hidden_states_scale,
            gemm1_weights=layer.w13_weight,
            gemm1_weights_scale=layer.w13_weight_scale,
            gemm2_weights=layer.w2_weight,
            gemm2_weights_scale=layer.w2_weight_scale,
            num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            # Keep Optional semantics: FlashInfer expects None for non-grouped
            # routing (e.g. Qwen3 Renormalize), not 0.
            n_group=n_group,
            topk_group=topk_group,
            intermediate_size=layer.intermediate_size_per_partition,
            local_expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=layer.local_num_experts,
            routed_scaling_factor=layer.routed_scaling_factor,
            routing_method_type=layer.routing_method_type,
            use_shuffled_weight=True,
            weight_layout=0,
            fp8_quantization_type=Fp8QuantizationType.MxFp8,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Placeholder for the non-monolithic path, which does not exist yet.

        Raises:
            NotImplementedError: Always, once the non-monolithic assertion
                passes.
        """
        assert not self.is_monolithic
        message = "Non-monolithic MXFP8 MoE path is not yet implemented."
        raise NotImplementedError(message)


# Register the method classes for ModelOptMxFp8Config. This is done here, after
# the method classes are defined, because they are referenced by name.
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.FusedMoEMethodCls = ModelOptMxFp8FusedMoE
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
    """Config class for ModelOpt MIXED_PRECISION.

    Supports checkpoints where different layers use different quantization
    algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
    The per-layer algorithm is specified in the ``quantized_layers`` dict
    inside ``config.json``'s ``quantization_config`` (preferred) or the
    legacy ``hf_quant_config.json``.
    """

    def __init__(
        self,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        quantized_layers: dict[str, dict[str, Any]],
        fp8_config: ModelOptFp8Config,
        nvfp4_config: ModelOptNvFp4Config,
    ) -> None:
        """Store the per-layer algorithm table and the two delegate configs.

        Args:
            kv_cache_quant_method: KV-cache quantization method, or ``None``.
            exclude_modules: Module patterns forwarded to the base class.
            quantized_layers: Mapping from layer name to its quantization info
                (each entry is read for a ``"quant_algo"`` key elsewhere in
                this class).
            fp8_config: Config handed to FP8 layer methods.
            nvfp4_config: Config handed to NVFP4 layer methods.
        """
        super().__init__(exclude_modules)
        self.fp8_config = fp8_config
        self.nvfp4_config = nvfp4_config
        self.quantized_layers = quantized_layers
        self.kv_cache_quant_method = kv_cache_quant_method

    def get_name(self) -> QuantizationMethods:
        """Return the name under which this quantization method is registered."""
        return "modelopt_mixed"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts (bfloat16, half)."""
        supported = [torch.bfloat16, torch.half]
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this config (89)."""
        min_capability = 89
        return min_capability

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Select ``modelopt_mixed`` for MIXED_PRECISION checkpoints.

        Returns ``"modelopt_mixed"`` when the algorithm extracted from
        ``hf_quant_cfg`` equals ``"MIXED_PRECISION"``, otherwise ``None``.
        ``user_quant`` is not consulted.
        """
        extracted = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if extracted is None:
            return None
        return "modelopt_mixed" if extracted == "MIXED_PRECISION" else None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptMixedPrecisionConfig":
        """Build the mixed-precision config and its FP8/NVFP4 delegate configs.

        ``quantized_layers`` is read from ``original_config["quantization"]``
        when that section exists, otherwise from the top level of
        ``original_config``. If ``group_size`` is not given, it is taken from
        the first NVFP4 entry of ``quantized_layers`` and finally defaults to
        16.

        Raises:
            ValueError: If ``quantized_layers`` is missing or empty.
        """
        if "quantization" in original_config:
            holder = original_config["quantization"]
        else:
            holder = original_config
        quantized_layers = holder.get("quantized_layers", {})

        if not quantized_layers:
            raise ValueError(
                "MIXED_PRECISION quant_algo requires a non-empty "
                "'quantized_layers' mapping in the quantization config."
            )

        # Fall back to the first NVFP4 entry's group size, then to 16.
        if group_size is None:
            first_nvfp4 = next(
                (
                    info
                    for info in quantized_layers.values()
                    if info.get("quant_algo", "").upper() == "NVFP4"
                ),
                None,
            )
            if first_nvfp4 is not None:
                group_size = first_nvfp4.get("group_size", 16)
        if group_size is None:
            group_size = 16

        # The delegate configs carry no exclusions of their own; exclusions
        # are kept on the mixed config returned below.
        fp8_config = ModelOptFp8Config(
            quant_method="FP8",
            is_checkpoint_fp8_serialized=True,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=[],
        )
        nvfp4_config = ModelOptNvFp4Config(
            is_checkpoint_nvfp4_serialized=True,
            kv_cache_quant_algo=kv_cache_quant_method,
            exclude_modules=[],
            group_size=group_size,
        )

        return cls(
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            quantized_layers=quantized_layers,
            fp8_config=fp8_config,
            nvfp4_config=nvfp4_config,
        )

    def _resolve_quant_algo(self, prefix: str) -> str | None:
        """Look up the quant_algo for a vLLM-side layer prefix.

        Tries three strategies in order:
        1. Direct lookup in ``quantized_layers``.
        2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
        3. Prefix-based lookup for FusedMoE (any child key starts with
           ``prefix + "."``).

        Returns the upper-cased quant_algo string, or *None* if the prefix
        is not found.
        """
        # 1. Direct lookup
        if prefix in self.quantized_layers:
            return self.quantized_layers[prefix]["quant_algo"].upper()

        # 2. Packed / fused layer lookup
        proj_name = prefix.rsplit(".", 1)[-1]
        if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
            algos: set[str] = set()
            base = prefix.rsplit(".", 1)[0]
            for shard_name in self.packed_modules_mapping[proj_name]:
                shard_prefix = f"{base}.{shard_name}"
                if shard_prefix in self.quantized_layers:
                    algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
            if len(algos) == 1:
                return algos.pop()
            if len(algos) > 1:
                raise ValueError(
                    f"Mixed quant_algo within fused layer {prefix}: "
                    f"{algos}. All shards must use the same quantization."
                )

        # 3. Prefix-based lookup (for FusedMoE / parent modules)
        prefix_dot = prefix + "."
        for key, info in self.quantized_layers.items():
            if key.startswith(prefix_dot):
                return info["quant_algo"].upper()

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantize-method for ``layer`` from its type and prefix.

        * ``Attention`` layers get the FP8 KV-cache method when a KV-cache
          quant method is configured, else ``None``.
        * Excluded layers are left unquantized (``UnquantizedLinearMethod``
          for linear layers, ``None`` otherwise).
        * Linear and FusedMoE layers are dispatched on the algorithm returned
          by ``_resolve_quant_algo`` (``"FP8"`` or ``"NVFP4"``); a linear layer
          with any other result stays unquantized, a FusedMoE layer gets
          ``None``.
        """
        # KV-cache quantization.
        # NOTE(review): only `Attention` is checked here — confirm whether
        # MLA attention layers are meant to be covered as well.
        if isinstance(layer, Attention):
            if not self.kv_cache_quant_method:
                return None
            return ModelOptFp8KVCacheMethod(self)

        is_linear = isinstance(layer, LinearBase)

        # Excluded layers.
        if self.is_layer_excluded(prefix):
            return UnquantizedLinearMethod() if is_linear else None

        algo = self._resolve_quant_algo(prefix)

        if is_linear:
            if algo == "FP8":
                return ModelOptFp8LinearMethod(self.fp8_config)
            if algo == "NVFP4":
                return ModelOptNvFp4LinearMethod(self.nvfp4_config)
            # No FP8/NVFP4 entry for this layer: leave it unquantized.
            return UnquantizedLinearMethod()

        if not isinstance(layer, FusedMoE):
            return None

        if algo == "FP8":
            return ModelOptFp8MoEMethod(
                quant_config=self.fp8_config,
                moe_config=layer.moe_config,
            )
        if algo == "NVFP4":
            return ModelOptNvFp4FusedMoE(
                quant_config=self.nvfp4_config,
                moe_config=layer.moe_config,
            )
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Apply the given weights mapper to this config.

        Delegates to the parent implementation first, then passes a non-empty
        ``quantized_layers`` mapping through ``hf_to_vllm_mapper.apply_dict``
        and stores the result back on the instance.
        """
        super().apply_vllm_mapper(hf_to_vllm_mapper)
        # An empty mapping is left as-is.
        if self.quantized_layers:
            self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)