modelopt.py 65.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any
6
7
8
9

import torch
from torch.nn.parameter import Parameter

10
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
11
from vllm.logger import init_logger
12
13
14
from vllm.model_executor.kernels.linear import (
    init_fp8_linear_kernel,
)
15
from vllm.model_executor.layers.attention import Attention
16
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
17
from vllm.model_executor.layers.fused_moe.config import (
18
    FusedMoEConfig,
19
20
    FusedMoEQuantConfig,
)
21
from vllm.model_executor.layers.fused_moe.layer import (
22
23
24
25
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
26
27
28
29
30
31
32
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
33
34
35
36
37
38
39
40
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
46
from vllm.model_executor.layers.quantization import QuantizationMethods
47
from vllm.model_executor.layers.quantization.base_config import (
48
49
50
    QuantizationConfig,
    QuantizeMethodBase,
)
51
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
52
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
53
    flashinfer_trtllm_fp4_moe,
54
    flashinfer_trtllm_fp4_routed_moe,
55
)
56
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
57
    apply_fi_trtllm_fp8_per_tensor_moe,
58
)
59
60
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
61
62
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
63
)
64
65
66
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
67
68
69
70
71
72
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    MXFP8_BLOCK_SIZE,
    MXFP8_SCALE_DTYPE,
    MXFP8_VALUE_DTYPE,
    Mxfp8LinearBackend,
    Mxfp8LinearOp,
73
    swizzle_mxfp8_scale,
74
)
75
76
77
78
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
    apply_nvfp4_linear,
    convert_to_nvfp4_linear_kernel_format,
    select_nvfp4_linear_backend,
79
)
80
from vllm.model_executor.layers.quantization.utils.quant_utils import (
81
82
    GroupShape,
    is_layer_skipped,
83
84
85
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
86
87
    kNvfp4Dynamic,
    kNvfp4Static,
88
)
89
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
90
    cutlass_block_fp8_supported,
91
92
    requantize_with_max_scale,
)
93
94
95
96
97
98
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
99
from vllm.model_executor.utils import replace_parameter
100

101
102
103
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

104
105
# Module-level logger; used below, e.g. for the experimental-format warning
# emitted by ModelOptFp8Config.
logger = init_logger(__name__)

106
107
108
109
110
111
112
113
114
# Weight/activation "quant_algo" values accepted for ModelOpt checkpoints.
# ModelOptQuantConfigBase.from_config upper-cases the checkpoint value before
# checking membership in this list.
QUANT_ALGOS = [
    "FP8",  # per-tensor FP8 weight + optional static activation scale
    "FP8_PER_CHANNEL_PER_TOKEN",  # per-channel weight, per-token activation scale
    "FP8_PB_WO",  # per-block FP8, weight-only (ModelOpt may emit lowercase)
    "NVFP4",  # FP4
    "MXFP8",  # MXFP8
]
# "kv_cache_quant_algo" values listed as supported for ModelOpt checkpoints.
KV_CACHE_QUANT_ALGOS = ["FP8"]
119
120


121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantize method for ModelOpt FP8 checkpoints.

    Enables loading the KV-cache scaling factors stored in FP8 checkpoints.
    All of the behavior is inherited from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Nothing ModelOpt-specific to initialize; the base class does it all.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base for the ModelOpt quantization configs.

    Implements what every ModelOpt format shares:

    * parsing the quantization section of ``hf_quant_config.json`` (legacy
      ModelOpt layout) or ``config.json`` (compressed-tensors style layout),
    * deciding which layers are excluded from quantization,
    * dispatching each layer to the quantize-method class a subclass selects
      through ``LinearMethodCls``, ``FusedMoEMethodCls`` and
      ``KVCacheMethodCls``.

    Subclasses implement :meth:`_from_config` to build themselves from the
    validated values.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # Module names / substrings / fnmatch wildcards to keep unquantized.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Three strategies are tried in order against ``self.exclude_modules``:

        1. Exact matching via ``is_layer_skipped`` (handles fused layers
           through ``packed_modules_mapping``).
        2. Legacy substring matching: an exclude entry contained in ``prefix``.
        3. ModelOpt wildcard matching with ``fnmatch``.

        Args:
            prefix: Fully qualified module name of the layer.

        Returns:
            True if the layer must stay unquantized.
        """
        if not self.exclude_modules:
            return False

        # 1. Exact matching with fused layer support.
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # 2. TODO: This special hard coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where exclusions are
        # handled naturally by the exclude_modules config. It is kept for
        # loading quantized checkpoints generated by older versions.
        # Layers named "language_model.<name>" need no special casing: any
        # substring of "<name>" is also a substring of the full prefix.
        # Empty entries are skipped because "" is a substring of every prefix
        # and would silently disable quantization for the whole model.
        if any(
            exclude_module and exclude_module != prefix and exclude_module in prefix
            for exclude_module in self.exclude_modules
        ):
            return True

        # 3. ModelOpt exclude modules are not simple strings, they are wildcards.
        return any(
            fnmatch(prefix, wildcard_pattern)
            for wildcard_pattern in self.exclude_modules
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method for ``layer``, or None.

        * Attention layers always get ``KVCacheMethodCls``.
        * Excluded layers (and, for older checkpoints, vision-encoder layers)
          stay unquantized: linear layers get ``UnquantizedLinearMethod`` and
          every other layer type gets None.
        * Remaining linear / fused-MoE layers get ``LinearMethodCls`` /
          ``FusedMoEMethodCls``; any other layer type gets None.
        """
        # Handle kv-cache first so we can focus only on weight quantization.
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # TODO: The vision_tower / vision_model hard-coding is not needed for
        # quantized checkpoints generated by ModelOpt >= 0.39.0, whose
        # exclude_modules already cover the vision encoder. It is kept for
        # loading quantized checkpoints generated by older versions.
        is_legacy_vision_module = "vision_tower" in prefix or "vision_model" in prefix

        if self.is_layer_excluded(prefix) or is_legacy_vision_module:
            # UnquantizedLinearMethod only makes sense for linear layers; for
            # any other layer type return None, exactly like plain exclusion.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # Now the layer is quantized; pick the configured method class.
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
        else:
            return None

        # Marlin-backed methods take a per-layer input dtype.
        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``exclude_modules`` from HF names to vLLM module names."""
        if not self.exclude_modules:
            return

        # This is a workaround for the weights remapping issue:
        # https://github.com/vllm-project/vllm/issues/28072
        # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
        #        module_path*
        # It gets applied if the whole tree of modules rooted at module_path
        # is not quantized. Here we replace such pattern by 2 patterns that are
        # collectively equivalent to the original pattern:
        #        module_path
        #        module_path.*
        new_exclude_modules: list[str] = []
        for exclude in self.exclude_modules:
            if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
                stem = exclude[:-1]
                new_exclude_modules.extend((stem, stem + ".*"))
            else:
                new_exclude_modules.append(exclude)

        self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from validated values (subclass hook)."""
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a ModelOpt quantization config dict.

        Accepts both the legacy ModelOpt layout
        (``{"quantization": {"quant_algo": ...}}``) and the compressed-tensors
        style layout (``{"quant_algo": ..., "quant_method": "modelopt"}``).

        Raises:
            ValueError: on a missing/unsupported ``quant_algo`` or on
                ill-typed ``kv_cache_quant_algo``, exclude list or
                ``group_size``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        if kv_cache_quant_method is None:
            # No KV cache quantization, keep this branch just to have this comment
            pass
        elif not isinstance(kv_cache_quant_method, str):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )
        else:
            kv_cache_quant_method = kv_cache_quant_method.upper()

        # An explicit JSON null means "nothing is excluded".
        if exclude_modules is None:
            exclude_modules = []
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8.

    Serves the FP8 ``quant_algo`` variants ``FP8``,
    ``FP8_PER_CHANNEL_PER_TOKEN`` and ``FP8_PB_WO``. The variant decides
    which linear method is used for linear layers.
    """

    # Upper-cased quant_algo values this config can serve. Keep in sync with
    # the LinearMethod mapping in __init__.
    _SUPPORTED_QUANT_ALGOS: tuple[str, ...] = (
        "FP8",
        "FP8_PER_CHANNEL_PER_TOKEN",
        "FP8_PB_WO",
    )

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method

        # Select LinearMethod implementation based on quant_algo. The method
        # classes are defined further down in this module, so the mapping is
        # built at call time rather than at class-definition time.
        linear_method_by_algo = {
            "FP8": ModelOptFp8LinearMethod,
            "FP8_PER_CHANNEL_PER_TOKEN": ModelOptFp8PcPtLinearMethod,
            "FP8_PB_WO": ModelOptFp8PbWoLinearMethod,
        }
        linear_method_cls = linear_method_by_algo.get(self.quant_method)
        if linear_method_cls is None:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )
        self.LinearMethodCls = linear_method_cls

        # Warn only once the quant_algo is known to be usable.
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns "modelopt" when ``quant_method`` is "modelopt" and the
        ``quant_algo`` (legacy nested or compressed-tensors style) is one of
        the FP8 variants this class supports; otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; tolerate an explicit null.
        quant_method = str(hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Exact membership (not a substring test): "MXFP8" contains "FP8" but
        # is not served by this class.
        if str(quant_algo).upper() in cls._SUPPORTED_QUANT_ALGOS:
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
428

429
430
431
432

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation (input) scale. Dynamic scales may be supported later.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm
       support.
    2. Only the float8_e4m3fn data type is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static per-tensor scales for both activations and weights.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8StaticTensorSym,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=type(self).__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8-serialized checkpoints, one
        ``weight_scale`` / ``input_scale`` entry per output partition."""
        del input_size, output_size
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        rows = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = rows

        storage_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(rows, input_size_per_partition, dtype=storage_dtype),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # One scalar per logical output partition, for weights then inputs;
        # pre-filled with the most negative float32 until loaded.
        shard_count = len(output_partition_sizes)
        for scale_name in ("weight_scale", "input_scale"):
            per_shard = PerTensorScaleParameter(
                data=torch.empty(shard_count, dtype=torch.float32),
                weight_loader=loader,
            )
            per_shard[:] = torch.finfo(torch.float32).min
            layer.register_parameter(scale_name, per_shard)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse per-partition scales into single per-tensor scales.

        If the loaded weight scales differ across partitions, the weight is
        requantized against the maximum scale. The weight is stored
        transposed, and the input scale becomes its maximum over partitions.
        """
        shard_scales = layer.weight_scale
        if (shard_scales == shard_scales[0]).all():
            fp8_weight, w_scale = layer.weight, shard_scales.max()
        else:
            w_scale, fp8_weight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )

        layer.weight = Parameter(fp8_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM selected in ``__init__``."""
        kernel = self.fp8_linear
        return kernel.apply_weights(layer, x, bias)
517
518


519
520
521
522
523
524
525
526
527
528
529
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8_PER_CHANNEL_PER_TOKEN`` checkpoints.

    Checkpoint layout expected for every Linear:
    - weight: float8_e4m3fn, shape [out, in]
    - weight_scale: float32, shape [out] (one scale per output channel)
    - no input_scale: activations are quantized dynamically, per token.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic per-token activation quantization; weights use the static
        # kFp8StaticTokenSym key (one scale per output row).
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8StaticTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=type(self).__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and per-channel ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        rows = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = rows

        fp8_storage = torch.empty(
            rows, input_size_per_partition, dtype=torch.float8_e4m3fn
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=fp8_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        channel_scales = ChannelQuantScaleParameter(
            data=torch.empty(rows, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        # Pre-fill with the most negative float32 until the checkpoint loads.
        channel_scales[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Store the weight transposed and re-wrap the scales as frozen
        Parameters; scale values are left untouched."""
        transposed_weight = layer.weight.t()
        layer.weight = Parameter(transposed_weight, requires_grad=False)
        frozen_scales = layer.weight_scale.data
        layer.weight_scale = Parameter(frozen_scales, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM selected in ``__init__``."""
        kernel = self.fp8_linear
        return kernel.apply_weights(layer, x, bias)
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims. This class fixes the
    block size at ``_WEIGHT_BLOCK_SIZE``.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        # Weights use (block_n x block_k) scale blocks; activations are
        # quantized in (1 x block_k) groups.
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and the block-wise ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized, or if the
                output size, any output partition, or the input size is not a
                multiple of the block size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        output_size_per_partition = sum(output_partition_sizes)

        # Validate shapes before allocating any parameter.
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        # With several output partitions (merged projections), each shard of
        # weight_scale is loaded at an offset translated from element-space to
        # block-space. A shard that is not a whole number of blocks would then
        # start mid-block, so every partition must be aligned, not just the sum.
        if any(size % block_n != 0 for size in output_partition_sizes):
            raise ValueError(
                "ModelOpt FP8_PB_WO requires every output partition size to be "
                f"divisible by {block_n}, got {output_partition_sizes}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        # Pre-fill with the most negative float32 until the checkpoint loads.
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Freeze the weight and flatten ``weight_scale`` to 2D.

        Raises:
            ValueError: if ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-scaled FP8 GEMM; activations get no static scale."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    static per-tensor activation scales.

    Args:
        quant_config: The ModelOpt quantization config.
        moe_config: The fused-MoE layer configuration.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        assert self.quant_config.is_checkpoint_fp8_serialized

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Unsupported: the modular kernel is built in ``_setup_kernel``."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Unsupported: the modular kernel is built in ``_setup_kernel``."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights and their per-tensor scales on ``layer``.

        With ``E`` = ``num_experts``, ``H`` = ``hidden_size``,
        ``I`` = ``intermediate_size_per_partition`` and ``S`` = 2 for gated
        (act-and-mul) MoE else 1, this registers ``w13_weight`` ``[E, S * I, H]``,
        ``w2_weight`` ``[E, H, I]``, ``w13_weight_scale`` ``[E, S]``,
        ``w2_weight_scale`` ``[E]`` and per-expert ``w13_input_scale`` /
        ``w2_input_scale`` ``[E]``. All scales are float32 initialized to 1.0.
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, w13_num_shards),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend's kernel format and build the kernel.

        Replaces the layer's weight / weight-scale parameters with the
        converted tensors, then, if a quant config is produced, stores the
        modular kernel in ``self.moe_mk``.
        """
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reduce loaded scales to the per-tensor form and set up the kernel."""
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 quant config from the layer's weight/input scales."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    @property
    def is_monolithic(self) -> bool:
        """True when the FlashInfer TRTLLM backend is selected."""
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and experts in one FlashInfer TRTLLM FP8 call.

        Raises:
            NotImplementedError: if EPLB is enabled on ``layer``.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
            )

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert layer.activation in SUPPORTED_ACTIVATIONS, (
            f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
            f"TRTLLM FP8 MoE, {layer.activation} found instead."
        )
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts through the modular kernel on pre-computed routing."""
        assert not self.is_monolithic

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            assert layer.activation in (
                MoEActivation.SILU,
                MoEActivation.RELU2_NO_MUL,
            ), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.moe_mk is not None
        return self.moe_mk(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )

# Attach the linear / fused-MoE / KV-cache method classes to the FP8 config.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Always set these (previously only for serialized checkpoints), so
        # readers of a non-serialized config do not hit AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``hf_quant_cfg`` has
        ``quant_method == "modelopt"`` (case-insensitive) and either a nested
        ``quantization.quant_algo`` containing ``"NVFP4"`` or a top-level
        ``quant_algo`` containing ``"FP4"`` (case-insensitive). Otherwise
        returns ``None``. ``user_quant`` is not consulted.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. A missing or null value
        # (or a non-string) means "not ModelOpt" rather than a crash.
        quant_method = hf_quant_cfg.get("quant_method") or ""
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                # A null quant_algo would make the `in` test raise TypeError.
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed checkpoint quantization fields.

        Raises:
            ValueError: if an NVFP4 checkpoint's ``quantization`` section is
                missing ``group_size``, ``kv_cache_quant_algo`` or
                ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure, where
    ``X`` is the per-partition output size and ``Y`` the per-partition input
    size:

    input_scale: torch.float32, one value per output shard,
    weight: NVFP4 packed two values per byte (uint8), Shape: [X, Y // 2],
    weight_scale: FP8-E4M3 per-block scale, Shape: [X, Y // group_size],
    weight_scale_2: torch.float32, one value per output shard.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None
        self.backend = select_nvfp4_linear_backend()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed NVFP4 weight and its scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Per-block scales are stored as FP8-E4M3. Non-serialized checkpoints
        # were rejected above, so no params_dtype fallback is needed here.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Global Scale (one per output shard)
        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_global_scale)

        # Weight Global Scale (one per output shard)
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_global_scale)

        # Per Block Weight Scale: one scale per group_size input elements.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse per-shard global scales and convert to the kernel format.

        ``input_scale`` and ``weight_scale_2`` are reduced (max over shards)
        to float32 ``input_global_scale`` / ``weight_global_scale``; then
        ``alpha = input_global_scale * weight_global_scale`` and
        ``input_global_scale_inv`` are precomputed before converting the layer
        for ``self.backend``.
        """
        # Rename ModelOpt checkpoint names to standardized names
        input_global_scale = layer.input_scale.max().to(torch.float32)
        layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
        del layer.input_scale
        weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
        del layer.weight_scale_2

        # Pre-compute alpha and inverse for runtime quantization
        layer.alpha = Parameter(
            layer.input_global_scale * layer.weight_global_scale, requires_grad=False
        )
        layer.input_global_scale_inv = Parameter(
            (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False
        )

        # Convert layer to NVFP4 linear kernel format
        convert_to_nvfp4_linear_kernel_format(self.backend, layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear op for ``layer`` on ``x`` via ``self.backend``."""
        return apply_nvfp4_linear(
            backend=self.backend,
            layer=layer,
            x=x,
            bias=bias,
        )

class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """MoE Method for FP4 Quantization.

    Args:
        quant_config: NVFP4 Quant Config
        moe_config: The fused-MoE layer configuration.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

        # Whether the backend uses global scale factors; create_weights uses
        # this to size the input scales by global vs. local expert count.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Unsupported: the modular kernel is built after weight loading."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Unsupported: the modular kernel is built after weight loading."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed NVFP4 expert weights and their scales on ``layer``.

        With ``E`` = ``num_experts``, ``H`` = ``hidden_size``,
        ``I`` = ``intermediate_size_per_partition``, ``G`` = group size and
        ``S`` = 2 for gated (act-and-mul) MoE else 1:

        - ``w13_weight`` uint8 ``[E, S * I, H // 2]`` and ``w2_weight`` uint8
          ``[E, H, I // 2]`` (two FP4 values packed per byte);
        - ``w13_weight_scale`` / ``w2_weight_scale`` FP8-E4M3 block scales
          ``[E, S * I, H // G]`` / ``[E, H, I // G]``;
        - ``w13_weight_scale_2`` ``[E, S]`` / ``w2_weight_scale_2`` ``[E]``
          float32 per-tensor scales;
        - ``w13_input_scale`` ``[N, S]`` / ``w2_input_scale`` ``[N]`` float32,
          where ``N`` is ``global_num_experts`` if the backend uses global
          scale factors, else ``E``.

        Raises:
            ValueError: if the backend uses global scale factors but
                ``global_num_experts`` is missing from ``extra_weight_attrs``.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # NOTE(review): input_dim/output_dim below are the reverse of the
        # tensor layout (dim 1 holds output rows, dim 2 the packed input),
        # unlike ModelOptFp8MoEMethod (input_dim=2, output_dim=1). Left as-is
        # because how the fused-MoE weight_loader uses these attributes is not
        # visible here -- confirm before changing.

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block weight scales: one FP8 scale per group_size input elements.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-tensor weight scales (one per shard for w13, one for w2).
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )
        if global_sf_num_experts is None:
            # Previously surfaced as an opaque TypeError from torch.empty(None).
            raise ValueError(
                "global_num_experts must be provided in extra_weight_attrs when "
                f"the {self.nvfp4_backend} NVFP4 MoE backend uses global scale "
                "factors."
            )

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.
        """
        # Use a single gscale for w13.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    @property
    def do_post_quant_allgather(self):
        """True when the FlashInfer TRTLLM NVFP4 backend is selected."""
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes ``hidden_states`` to FP4 with FlashInfer and returns the
        quantized tensor plus its (non-swizzled) scale factors.

        Raises:
            RuntimeError: if the backend is not FlashInfer TRTLLM.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        import flashinfer

        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 quant config from the layer's loaded scales."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        """True for the FlashInfer TRTLLM backend when EPLB is disabled."""
        return (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not self.moe.moe_parallel_config.enable_eplb
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and experts in one FlashInfer TRTLLM FP4 call."""
        assert self.is_monolithic
        assert (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        )

        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=layer.top_k,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            custom_routing_function=layer.custom_routing_function,
            e_score_correction_bias=layer.e_score_correction_bias,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts on pre-computed routing.

        With the FlashInfer TRTLLM backend this is the EPLB path (routed
        FlashInfer call); otherwise the modular kernel is used.
        """
        assert not self.is_monolithic

        # EPLB path
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
            )
        else:
            assert self.moe_mk is not None
            return self.moe_mk(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                shared_experts_input=shared_experts_input,
            )


# Attach the linear / fused-MoE / KV-cache method classes to the NVFP4 config.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptMxFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt MXFP8.

    Only pre-quantized (serialized) MXFP8 checkpoints are supported; dynamic
    quantization is rejected at construction time and MoE layers are rejected
    in ``get_quant_method``.
    """

    def __init__(
        self,
        is_checkpoint_mxfp8_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
    ) -> None:
        """
        Args:
            is_checkpoint_mxfp8_serialized: Whether the checkpoint already
                stores MXFP8-quantized weights. Must be True.
            kv_cache_quant_algo: KV-cache quantization algorithm name taken
                from the checkpoint config, or None.
            exclude_modules: Modules to leave unquantized; forwarded to the
                base config.

        Raises:
            ValueError: If the checkpoint is not MXFP8 serialized.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized

        if not is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization requires a serialized checkpoint. "
                "Dynamic quantization is not supported."
            )

        logger.warning(
            "Detected ModelOpt MXFP8 checkpoint. Please note that "
            "the format is experimental and could change in future."
        )

        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "modelopt_mxfp8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by the MXFP8 kernels."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # MXFP8 hardware acceleration requires Blackwell (SM100) or newer
        return 100

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method for ``layer``.

        Raises:
            NotImplementedError: If ``layer`` is a FusedMoE layer, which MXFP8
                does not support yet.
        """
        if isinstance(layer, FusedMoE):
            raise NotImplementedError(
                "MXFP8 quantization does not yet support MoE models. "
                "Please use FP8 or NVFP4 quantization for MoE models."
            )
        return super().get_quant_method(layer, prefix)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt MXFP8 config should be used based on
        quantization config.

        Args:
            hf_quant_cfg: The checkpoint's quantization config dict, or None.
            user_quant: Quantization method requested by the user (unused).

        Returns:
            ``"modelopt_mxfp8"`` when ``quant_method`` is ``"modelopt"`` and
            the ``quant_algo`` contains ``"MXFP8"`` (case-insensitive),
            otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. Coerce to str so that an
        # explicit null / non-string value means "not ModelOpt" rather than
        # raising AttributeError on .lower().
        quant_method = str(hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # ModelOpt configs either nest settings under "quantization" or place
        # quant_algo directly at the top level.
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        if "MXFP8" in str(quant_algo).upper():
            return "modelopt_mxfp8"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptMxFp8Config":
        """Build an MXFP8 config from already-parsed checkpoint fields.

        Raises:
            ValueError: If a nested ``quantization`` section lacks a field
                required for MXFP8, or (from ``__init__``) if ``quant_method``
                does not denote MXFP8.
        """
        is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper()

        # For MXFP8, validate required fields in the config
        if is_checkpoint_mxfp8_serialized and "quantization" in original_config:
            quant_config = original_config["quantization"]
            required_fields = ["kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"MXFP8 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_mxfp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )


class ModelOptMxFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt MXFP8 quantization.

    Weights are loaded as ``MXFP8_VALUE_DTYPE`` tensors of shape ``[N, K]``
    together with a ``MXFP8_SCALE_DTYPE`` scale of shape
    ``[N, K // MXFP8_BLOCK_SIZE]`` (one scale per block along the input
    dimension). After loading, the scale is converted to the layout expected
    by the selected GEMM backend.
    """

    def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
        self.quant_config = quant_config

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 currently only supports serialized checkpoints. "
                "Dynamic quantization is not supported."
            )

        # FlashInfer CUTLASS is always selected here; process_weights_after_loading
        # also handles Mxfp8LinearBackend.EMULATION.
        self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
        self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
        logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the MXFP8 ``weight`` and ``weight_scale`` parameters.

        Args:
            layer: Linear layer that receives the parameters.
            input_size_per_partition: K dimension of this partition; must be
                a multiple of ``MXFP8_BLOCK_SIZE``.
            output_partition_sizes: Widths of the fused output sub-matrices;
                their sum is the N dimension of this partition.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Unused; the weight is stored as ``MXFP8_VALUE_DTYPE``.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not MXFP8 serialized or the input
                dimension is not a multiple of ``MXFP8_BLOCK_SIZE``.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization was selected, but checkpoint is not "
                "MXFP8 serialized. Dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
            raise ValueError(
                f"MXFP8 requires input dimension to be divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}"
            )

        # Weight tensor: FP8 E4M3 format
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=MXFP8_VALUE_DTYPE,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Weight scale tensor: one MXFP8_SCALE_DTYPE value per block of
        # MXFP8_BLOCK_SIZE elements along K.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // MXFP8_BLOCK_SIZE,
                dtype=MXFP8_SCALE_DTYPE,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    @staticmethod
    def _trimmed_weight_scale(layer: torch.nn.Module) -> torch.Tensor:
        """Return ``weight_scale`` sliced to ``[N, K // MXFP8_BLOCK_SIZE]``.

        The loaded scale may be padded beyond the weight's extent; slicing
        drops the padding so the scale matches ``weight`` exactly.
        """
        n, k = layer.weight.shape
        return layer.weight_scale.data[:n, : k // MXFP8_BLOCK_SIZE].contiguous()

    def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
        """Not swizzled - MXFP8 GEMM emulation (plain 2D scale layout)."""
        weight_scale = self._trimmed_weight_scale(layer)
        layer.weight = Parameter(layer.weight.data.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
        """Swizzled - MXFP8 GEMM Flashinfer CUTLASS."""
        n, k = layer.weight.shape
        weight_scale = self._trimmed_weight_scale(layer)
        weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale, M=n, K=k)

        layer.weight = Parameter(layer.weight.data.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(
            weight_scale_swizzled.contiguous(), requires_grad=False
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate the loaded MXFP8 tensors and convert the scale layout.

        Raises:
            ValueError: If ``weight`` or ``weight_scale`` has an unexpected
                rank or dtype, or if the scale is too small to cover the
                weight.
        """
        weight = layer.weight
        if weight.ndim != 2:
            raise ValueError(
                f"MXFP8 weight must be 2D tensor [N, K], got {weight.ndim}D "
                f"with shape {tuple(weight.shape)}"
            )

        if weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), "
                f"got {weight.dtype}. The checkpoint may not be properly "
                f"quantized with MXFP8."
            )

        # Validate weight scale tensor (should be 2D, not swizzled). Raise rather
        # than assert so the checks are not stripped under `python -O`.
        weight_scale = layer.weight_scale
        if weight_scale.ndim != 2:
            raise ValueError(
                f"MXFP8 weight scale must be 2D, got {weight_scale.ndim}D"
            )
        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE},"
                f" got {weight_scale.dtype}"
            )

        # The scale may be padded (it is trimmed later) but must never be
        # smaller than [N, K // MXFP8_BLOCK_SIZE]; otherwise the trimmed slice
        # would silently come out too small.
        n, k = weight.shape
        scale_k = k // MXFP8_BLOCK_SIZE
        if weight_scale.shape[0] < n or weight_scale.shape[1] < scale_k:
            raise ValueError(
                f"MXFP8 weight scale shape {tuple(weight_scale.shape)} is too "
                f"small for weight shape {tuple(weight.shape)}; expected at "
                f"least ({n}, {scale_k})"
            )

        if self.backend == Mxfp8LinearBackend.EMULATION:
            # Swizzled layout is not used
            self._process_weights_after_loading_scale_2d(layer)
            return

        assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
        # Swizzled layout is required for Flashinfer CUTLASS
        self._process_weights_after_loading_scale_1d(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the MXFP8 linear op to ``x``, producing output in ``x.dtype``.

        Raises:
            ValueError: If the layer's weight or scale dtype is not the
                expected MXFP8 dtype.
        """
        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
            )
        if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Weight scale dtype {layer.weight_scale.dtype} != "
                f"expected {MXFP8_SCALE_DTYPE}"
            )

        return self.mxfp8_linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )


# Wire ModelOptMxFp8Config to its per-layer method implementations. No
# FusedMoEMethodCls is set because get_quant_method rejects FusedMoE layers.
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod