modelopt.py 55.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any
6
7
8
9

import torch
from torch.nn.parameter import Parameter

10
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.attention import Attention
13
from vllm.model_executor.layers.fused_moe.config import (
14
    FusedMoEConfig,
15
16
    FusedMoEQuantConfig,
)
17
from vllm.model_executor.layers.fused_moe.layer import (
18
19
20
21
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
22
23
24
25
26
27
28
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
29
30
31
32
33
34
35
36
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
37
38
39
40
41
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
from vllm.model_executor.layers.quantization.base_config import (
44
45
46
    QuantizationConfig,
    QuantizeMethodBase,
)
47
48
49
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
50
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
51
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
52
    flashinfer_trtllm_fp4_moe,
53
    flashinfer_trtllm_fp4_routed_moe,
54
)
55
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
56
    apply_fi_trtllm_fp8_per_tensor_moe,
57
)
58
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
60
61
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
62
)
63
64
65
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
66
67
68
69
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
    apply_nvfp4_linear,
    convert_to_nvfp4_linear_kernel_format,
    select_nvfp4_linear_backend,
70
)
71
from vllm.model_executor.layers.quantization.utils.quant_utils import (
72
73
    GroupShape,
    is_layer_skipped,
74
75
76
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
77
78
    kNvfp4Dynamic,
    kNvfp4Static,
79
)
80
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
81
    cutlass_block_fp8_supported,
82
83
    requantize_with_max_scale,
)
84
85
86
87
88
89
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
90
from vllm.model_executor.utils import replace_parameter
91

92
93
94
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

95
96
logger = init_logger(__name__)

# ModelOpt ``quant_algo`` values accepted by ``ModelOptQuantConfigBase.from_config``
# (the incoming value is upper-cased before it is compared against this list).
QUANT_ALGOS = [
    # FP8: per-tensor weight scale plus an optional static activation scale.
    "FP8",
    # FP8: per-channel weight scale plus per-token activation scale.
    "FP8_PER_CHANNEL_PER_TOKEN",
    # FP8: per-block weight-only scaling (ModelOpt may emit this in lowercase).
    "FP8_PB_WO",
    # FP4 (NVFP4 format).
    "NVFP4",
]
# KV-cache quantization algorithm names for ModelOpt checkpoints.
KV_CACHE_QUANT_ALGOS = ["FP8"]


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for ModelOpt FP8 checkpoints.

    Loads the kv-cache scaling factors stored in FP8 checkpoints; all of the
    behavior comes from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # No ModelOpt-specific state; the base class does all the setup.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base for ModelOpt quantization configs.

    Parses ModelOpt quantization configs (``from_config``), decides which
    layers are excluded from quantization, and hands each layer an instance of
    the method class chosen by the subclass via ``LinearMethodCls``,
    ``FusedMoEMethodCls`` and ``KVCacheMethodCls``. Subclasses must implement
    ``_from_config``.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # ModelOpt exclusion list; entries may be fnmatch-style wildcards.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and ModelOpt wildcard
        matching. The ModelOpt exclude_modules list is a list of wildcards.

        Args:
            prefix: Fully-qualified module name of the layer.

        Returns:
            True if ``prefix`` matches an entry of ``exclude_modules``.
        """
        if not self.exclude_modules:
            return False

        # Exact matching, with fused-layer support via packed_modules_mapping.
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This substring matching is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0, where such layers are handled
        # naturally by the exclude_modules config, but it is kept for loading
        # checkpoints generated by older versions. Testing against ``prefix``
        # also covers the name with a leading "language_model." stripped: that
        # stripped name is a suffix of ``prefix``, so any substring of it is a
        # substring of ``prefix`` too.
        for exclude_module in self.exclude_modules:
            # Exact matches were handled above.
            if exclude_module != prefix and exclude_module in prefix:
                return True

        # ModelOpt exclude modules are not simple strings, they are wildcards.
        return any(fnmatch(prefix, pattern) for pattern in self.exclude_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method for ``layer``.

        Returns the KV-cache method for attention layers. Excluded layers get
        ``UnquantizedLinearMethod`` if they are linear and ``None`` otherwise.
        Linear and fused-MoE layers that are not excluded get the subclass's
        method classes. Any other layer gets ``None``.
        """
        # Handle kv-cache first so the rest only deals with weight quantization.
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # TODO: This hard-coded vision-encoder check is not needed for
        # checkpoints generated by ModelOpt >= 0.39.0, whose exclude_modules
        # already list those layers; it is kept for older checkpoints.
        is_legacy_vision_layer = "vision_tower" in prefix or "vision_model" in prefix

        # Excluded layers (explicitly or via the legacy vision check) stay
        # unquantized. Only linear layers get an explicit unquantized method;
        # every other layer type gets None. The vision check used to hand
        # UnquantizedLinearMethod to non-linear layers (e.g. FusedMoE) as well.
        if self.is_layer_excluded(prefix) or is_legacy_vision_layer:
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # From here on the layer is quantized.
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
        else:
            return None

        # Marlin-backed methods take a per-layer input dtype.
        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` from HF names to vLLM module names."""
        if not self.exclude_modules:
            return

        # This is a workaround for the weights remapping issue:
        # https://github.com/vllm-project/vllm/issues/28072
        # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
        #        module_path*
        # It gets applied if the whole tree of modules rooted at module_path
        # is not quantized. Here we replace such pattern by 2 patterns that are
        # collectively equivalent to the original pattern:
        #        module_path
        #        module_path.*
        new_exclude_modules: list[str] = []
        for exclude in self.exclude_modules:
            if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
                new_exclude_modules.append(exclude[:-1])
                new_exclude_modules.append(exclude[:-1] + ".*")
            else:
                new_exclude_modules.append(exclude)

        self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Build a config from a ModelOpt quantization config dict.

        Accepts both the legacy ``hf_quant_config.json`` layout
        (``{"quantization": {...}}``) and the compressed-tensors style layout
        (top-level ``quant_algo`` / ``ignore`` keys).

        Raises:
            ValueError: On a missing or unsupported ``quant_algo`` or on
                malformed ``kv_cache_quant_algo``, ``exclude_modules`` /
                ``ignore`` or ``group_size`` values.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        if kv_cache_quant_method is None:
            # No KV cache quantization.
            pass
        elif not isinstance(kv_cache_quant_method, str):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )
        else:
            kv_cache_quant_method = kv_cache_quant_method.upper()

        # Treat an explicit JSON null the same as a missing key.
        if exclude_modules is None:
            exclude_modules = []
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8.

    Supports the ``FP8``, ``FP8_PER_CHANNEL_PER_TOKEN`` and ``FP8_PB_WO``
    ModelOpt ``quant_algo`` values; each selects its own linear method.
    """

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """
        Args:
            quant_method: ModelOpt ``quant_algo`` (``from_config`` passes it
                upper-cased).
            is_checkpoint_fp8_serialized: Whether weights are stored in FP8.
            kv_cache_quant_method: KV-cache ``quant_algo`` or None.
            exclude_modules: ModelOpt exclusion patterns.

        Raises:
            ValueError: If ``quant_method`` is not one of the FP8 variants
                supported here.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select the LinearMethod implementation based on quant_algo. The
        # method classes are defined later in this module, so the table is
        # built at call time.
        linear_methods: dict[str, type] = {
            "FP8": ModelOptFp8LinearMethod,
            "FP8_PER_CHANNEL_PER_TOKEN": ModelOptFp8PcPtLinearMethod,
            "FP8_PB_WO": ModelOptFp8PbWoLinearMethod,
        }
        if self.quant_method not in linear_methods:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )
        self.LinearMethodCls = linear_methods[self.quant_method]

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns "modelopt" when ``quant_method`` is "modelopt" and the
        ``quant_algo`` (nested under "quantization" if present, otherwise at
        the top level) contains "FP8"; otherwise returns None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. Tolerate an explicit
        # null or non-string value instead of crashing on ``.lower()``.
        quant_method = str(hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific layout: {"quantization": {"quant_algo": ...}}
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
        else:
            # Compressed-tensors style layout with a top-level quant_algo.
            quant_config = hf_quant_cfg

        quant_algo = str(quant_config.get("quant_algo", ""))
        if "FP8" in quant_algo.upper():
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8`` (static quantization) checkpoints.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation scale. Dynamic scales may be supported in the future.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm support.
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static per-tensor FP8 keys for both activations and weights.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8StaticTensorSym,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8-serialized checkpoints, one
        ``weight_scale`` and one ``input_scale`` entry per output partition."""
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        total_out = sum(output_partition_sizes)
        num_shards = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        storage_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_out, input_size_per_partition, dtype=storage_dtype
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # Weight scale first, then input scale; each holds one float32 entry
        # per logical shard, pre-filled with the float32 minimum.
        for param_name in ("weight_scale", "input_scale"):
            per_shard_scale = PerTensorScaleParameter(
                data=torch.empty(num_shards, dtype=torch.float32),
                weight_loader=loader,
            )
            per_shard_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter(param_name, per_shard_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse per-shard scales into single per-tensor scales and store
        the weight transposed."""
        shard_scales = layer.weight_scale
        if bool((shard_scales == shard_scales[0]).all()):
            # Every shard already shares one scale.
            fp8_weight = layer.weight
            tensor_scale = shard_scales.max()
        else:
            # Shards differ: requantize_with_max_scale returns a single (max)
            # scale plus the weight requantized to it.
            tensor_scale, fp8_weight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )

        layer.weight = Parameter(fp8_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(tensor_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel selected in ``__init__``."""
        return self.fp8_linear.apply_weights(layer, x, bias)


class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8_PER_CHANNEL_PER_TOKEN`` checkpoints.

    Tensors expected in the checkpoint for each Linear:
    - ``weight``: float8_e4m3fn, shape [out, in]
    - ``weight_scale``: float32, shape [out] (one scale per output channel)
    - no ``input_scale``: activations are quantized dynamically per token
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic per-token activation key, static per-channel weight key.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8StaticTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and its per-channel ``weight_scale``.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        fp8_weight = ModelWeightParameter(
            data=torch.empty(
                out_features, input_size_per_partition, dtype=torch.float8_e4m3fn
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", fp8_weight)

        # One float32 scale per output channel, pre-filled with float32 min.
        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(out_features, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Store the weight transposed and re-wrap the scale as a frozen
        ``Parameter``."""
        transposed_weight = layer.weight.t()
        scale_tensor = layer.weight_scale.data
        layer.weight = Parameter(transposed_weight, requires_grad=False)
        layer.weight_scale = Parameter(scale_tensor, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel selected in ``__init__``."""
        return self.fp8_linear.apply_weights(layer, x, bias)


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    This implementation assumes 128x128 blocks (``_WEIGHT_BLOCK_SIZE``).

    vLLM executes it as an FP8 block GEMM with *dynamic* activation
    quantization per token in groups of ``block_k`` columns
    (``act_quant_group_shape=(1, block_k)``).
    """

    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and its block ``weight_scale``.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized, or if the
                total output size, any individual output partition, or the
                input size is not a multiple of the block size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        block_n, block_k = self._WEIGHT_BLOCK_SIZE

        # Validate shapes before mutating ``layer`` so bad shapes fail fast.
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        # Fused layers (e.g. q/k/v or gate/up) load each shard at an offset
        # that the weight loaders translate from element space to block space
        # (see ``layer.weight_block_size`` below). That is only exact when every
        # shard is a whole number of blocks; a divisible total is not enough.
        for partition_size in output_partition_sizes:
            if partition_size % block_n != 0:
                raise ValueError(
                    "ModelOpt FP8_PB_WO requires every output partition to be "
                    f"divisible by {block_n}, got {output_partition_sizes}."
                )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Squeeze the 4D block scale to 2D and freeze weight and scale.

        Raises:
            ValueError: If ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block FP8 GEMM; activation scales are computed on the fly."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        moe_config: The fused MoE layer configuration.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        # Only pre-quantized (serialized) FP8 checkpoints are supported.
        assert self.quant_config.is_checkpoint_fp8_serialized

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Not used by this method; always raises ``ValueError``."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Not used by this method; always raises ``ValueError``."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales.

        Scales are initialized to 1.0 and overwritten by the checkpoint
        loader (``extra_weight_attrs["weight_loader"]``).
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stores w1 and w3 stacked in w13; non-gated has one shard.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, w13_num_shards),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend's format and build the kernel."""
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reduce per-shard scales to per-tensor ones and set up the kernel."""
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 MoE quant config from the layer's current scales."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    @property
    def is_monolithic(self) -> bool:
        """True when ``apply_monolithic`` (router logits in) must be used."""
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and experts in one FlashInfer TRT-LLM FP8 call.

        Raises:
            NotImplementedError: if EPLB is enabled on ``layer``.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
            )
        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )
        assert not layer.renormalize
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts via the modular kernel with precomputed top-k."""
        assert not self.is_monolithic

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.moe_mk is not None
        return self.moe_mk(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )


# Register the per-component FP8 quantization methods on the FP8 config.
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the NVFP4 checkpoint settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether weights are stored
                pre-quantized as NVFP4 in the checkpoint.
            kv_cache_quant_algo: KV-cache quantization algorithm, if any.
            exclude_modules: Module name patterns left unquantized.
            group_size: Number of input elements sharing one block scale.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set unconditionally so these attributes exist for every instance,
        # not only for serialized checkpoints.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"`` and
        the algorithm names an FP4 variant, otherwise ``None``. Missing or
        non-string fields are treated as "not a match" instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method") or ""

        # Only proceed if the method is explicitly "modelopt"
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed ModelOpt fields.

        Raises:
            ValueError: if an NVFP4 checkpoint's ``quantization`` section lacks
                ``group_size``, ``kv_cache_quant_algo`` or ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters, with
    ``out = sum(output_partition_sizes)`` and ``in = input_size_per_partition``:

    - ``weight``: NVFP4 values packed two per byte (``torch.uint8``),
      shape ``[out, in // 2]``.
    - ``weight_scale``: FP8-E4M3 per-block scales,
      shape ``[out, in // group_size]``.
    - ``weight_scale_2``: ``torch.float32``, one entry per output partition;
      reduced to its max after loading.
    - ``input_scale``: ``torch.float32``, one entry per output partition;
      reduced to its max after loading.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None
        self.backend = select_nvfp4_linear_backend()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 weight, block scales and global scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Per-block scales are stored as FP8-E4M3 for serialized checkpoints.
        scale_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Global Scale (one per output partition)
        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_global_scale)

        # Weight Global Scale (one per output partition)
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_global_scale)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse global scales to scalars and convert to kernel format."""
        # Rename ModelOpt checkpoint names to standardized names
        input_global_scale = layer.input_scale.max().to(torch.float32)
        layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
        del layer.input_scale
        weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
        del layer.weight_scale_2

        # Pre-compute alpha and inverse for runtime quantization
        layer.alpha = Parameter(
            layer.input_global_scale * layer.weight_global_scale, requires_grad=False
        )
        layer.input_global_scale_inv = Parameter(
            (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False
        )

        # Convert layer to NVFP4 linear kernel format
        convert_to_nvfp4_linear_kernel_format(self.backend, layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear for ``x`` using the selected backend."""
        return apply_nvfp4_linear(
            backend=self.backend,
            layer=layer,
            x=x,
            bias=bias,
        )


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Args:
        quant_config: NVFP4 Quant Config.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        """Select the NVFP4 MoE experts backend.

        Args:
            quant_config: The ModelOpt NVFP4 quantization config.
            moe_config: The fused MoE layer configuration.
        """
        super().__init__(moe_config)
        self.quant_config = quant_config

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

        # Whether the selected backend supports global scale factors; this
        # decides how many experts the input scales are allocated for.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Not used by this method; always raises ``ValueError``."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Not used by this method; always raises ``ValueError``."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 expert weights and scales on ``layer``.

        Allocates packed FP4 weights (two values per ``uint8``), FP8-E4M3
        per-block scales, float32 global weight scales (``*_weight_scale_2``)
        and float32 input scales. For gated MoE (``is_act_and_mul``) the w13
        tensors hold two shards (w1 and w3).

        Args:
            layer: Module the parameters are registered on.
            num_experts: Number of experts to allocate weights for.
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Expert intermediate dimension of
                this partition.
            params_dtype: Original parameter dtype, recorded on ``layer``.
            **extra_weight_attrs: ``weight_loader`` is attached to every
                parameter; ``global_num_experts`` sizes the input scales when
                the backend uses global scale factors.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # One FP8 scale per group_size elements of the input dimension
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # One FP8 scale per group_size elements of the input dimension
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): extra_weight_attrs is not read again in this method,
        # so this and the later update do not affect the registered params.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Backends with global scale-factor support keep input scales for all
        # global experts; otherwise they are sized by num_experts.
        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.
        """
        # Use a single gscale for w13: the w1 shard's scale is kept, and a
        # mismatch with the w3 shard is only reported as a warning.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    @property
    def do_post_quant_allgather(self):
        """True only for the FlashInfer TRT-LLM NVFP4 MoE backend."""
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes ``hidden_states`` to FP4 with ``layer.a1_gscale`` (scale
        factors in non-swizzled layout) and returns the packed tensor plus
        its scale factors as the extra tensor list.

        Raises:
            RuntimeError: if the backend is not FlashInfer TRT-LLM.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        # Imported lazily: only needed on the FlashInfer TRT-LLM path.
        import flashinfer

        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 MoE quant config from the layer's current scales."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        """This MoE method supports EPLB."""
        return True

    @property
    def is_monolithic(self) -> bool:
        """Whether ``apply_monolithic`` (router logits in) must be used.

        Only the FlashInfer TRT-LLM backend is monolithic, and only when EPLB
        is disabled; otherwise ``apply`` with precomputed top-k is used.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            return False
        return not self.moe.moe_parallel_config.enable_eplb

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and the experts in one FlashInfer TRT-LLM FP4 call.

        Only valid for the FlashInfer TRT-LLM backend with EPLB disabled.
        """
        assert self.is_monolithic
        assert (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        )
        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=layer.top_k,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            custom_routing_function=layer.custom_routing_function,
            e_score_correction_bias=layer.e_score_correction_bias,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE experts using precomputed top-k routing.

        Used whenever ``is_monolithic`` is False:

        * FlashInfer TRTLLM backend (only non-monolithic when EPLB is
          enabled): dispatches to ``flashinfer_trtllm_fp4_routed_moe``.
        * Every other NVFP4 backend: calls the modular kernel
          ``self.moe_mk``, which must have been built.
        """
        assert not self.is_monolithic

        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            # EPLB path: TRTLLM is only non-monolithic when EPLB is enabled,
            # so the routed kernel consumes the caller's top-k selection.
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
            )

        # Modular-kernel path for the remaining NVFP4 backends.
        assert self.moe_mk is not None
        return self.moe_mk(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )


# Wire the NVFP4 config to the method classes it instantiates for linear,
# fused-MoE and KV-cache layers. The KV cache reuses the FP8 KV-cache method.
(
    ModelOptNvFp4Config.LinearMethodCls,
    ModelOptNvFp4Config.FusedMoEMethodCls,
    ModelOptNvFp4Config.KVCacheMethodCls,
) = (
    ModelOptNvFp4LinearMethod,
    ModelOptNvFp4FusedMoE,
    ModelOptFp8KVCacheMethod,
)