modelopt.py 62.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
    FusedMoEConfig,
18
19
    FusedMoEQuantConfig,
)
20
from vllm.model_executor.layers.fused_moe.layer import (
21
22
23
24
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
25
26
27
28
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
29
    make_fp8_moe_kernel_for_mkm,
30
31
32
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
33
34
35
36
37
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
38
    make_nvfp4_moe_kernel_for_mkm,
39
40
41
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
47
from vllm.model_executor.layers.quantization import QuantizationMethods
48
from vllm.model_executor.layers.quantization.base_config import (
49
50
51
    QuantizationConfig,
    QuantizeMethodBase,
)
52
53
54
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
55
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
56
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
57
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
58
    flashinfer_trtllm_fp4_moe,
59
    flashinfer_trtllm_fp4_routed_moe,
60
)
61
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
62
    apply_fi_trtllm_fp8_per_tensor_moe,
63
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
64
)
65
66
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
67
68
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
69
)
70
71
72
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
73
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
74
75
76
77
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
81
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
82
83
84
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
85
86
    kNvfp4Dynamic,
    kNvfp4Static,
87
88
89
    pad_nvfp4_activation_for_cutlass,
    pad_nvfp4_weight_for_cutlass,
    slice_nvfp4_output,
90
91
    swizzle_blockscale,
)
92
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
93
    cutlass_block_fp8_supported,
94
95
    requantize_with_max_scale,
)
96
97
98
99
100
101
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
102
from vllm.model_executor.utils import replace_parameter
103
104
105
106
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
)
107

108
109
110
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

111
112
logger = init_logger(__name__)

113
114
115
116
117
118
119
120
121
122
# ModelOpt ``quant_algo`` values accepted by this module (config values are
# upper-cased before being checked against this list):
#   FP8                       - per-tensor weight + optional static activation scale
#   FP8_PER_CHANNEL_PER_TOKEN - per-channel weight scale + per-token activation scale
#   FP8_PB_WO                 - per-block weight-only (ModelOpt may emit lowercase)
#   NVFP4                     - FP4
QUANT_ALGOS = ["FP8", "FP8_PER_CHANNEL_PER_TOKEN", "FP8_PB_WO", "NVFP4"]
# KV-cache quantization algorithms recognised for ModelOpt checkpoints.
KV_CACHE_QUANT_ALGOS = ["FP8"]
124
125


126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for ModelOpt FP8 checkpoints.

    Enables loading the KV-cache scaling factors stored in FP8 checkpoints;
    all behavior is inherited from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # No ModelOpt-specific state: defer entirely to the base class.
        super(ModelOptFp8KVCacheMethod, self).__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base for ModelOpt quantization configs.

    Holds the ModelOpt ``exclude_modules`` patterns, decides per layer which
    quantize method to use, and parses both supported config layouts.
    Subclasses choose the concrete method classes through ``LinearMethodCls``,
    ``FusedMoEMethodCls`` and ``KVCacheMethodCls`` and implement
    :meth:`_from_config`.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # ModelOpt exclude patterns (may contain wildcards); rewritten by
        # apply_vllm_mapper.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Three checks are applied in order:

        1. exact matching with fused-layer support (``is_layer_skipped``);
        2. legacy substring matching, also against ``prefix`` with a leading
           ``language_model.`` removed;
        3. ``fnmatch`` matching, because ModelOpt exclude entries are wildcards.

        Args:
            prefix: Fully qualified module name of the layer.

        Returns:
            True if any exclude pattern matches ``prefix``.
        """
        if len(self.exclude_modules) == 0:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module != prefix and (
                exclude_module in prefix
                or (
                    prefix.startswith("language_model.")
                    and exclude_module in prefix.removeprefix("language_model.")
                )
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        for wildcard_pattern in self.exclude_modules:
            if fnmatch(prefix, wildcard_pattern):
                return True

        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None.

        Attention layers always get ``KVCacheMethodCls``. Excluded layers and
        layers under a vision tower/model get ``UnquantizedLinearMethod`` when
        they are linear and None otherwise. Remaining linear and fused-MoE
        layers get ``LinearMethodCls`` / ``FusedMoEMethodCls``; if that method
        reports a ``"marlin"`` backend, its ``marlin_input_dtype`` is set from
        the layer prefix.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        if "vision_tower" in prefix or "vision_model" in prefix:
            # Same rule as the exclusion branch above: a linear-layer method is
            # only meaningful for linear layers; other layer types get None.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite the exclude patterns through the HF -> vLLM name mapper."""
        if len(self.exclude_modules) > 0:
            # This is a workaround for the weights remapping issue:
            # https://github.com/vllm-project/vllm/issues/28072
            # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
            #        module_path*
            # It gets applied if the whole tree of modules rooted at module_path
            # is not quantized. Here we replace such pattern by 2 patterns that are
            # collectively equivalent to the original pattern:
            #        module_path
            #        module_path.*
            new_exclude_modules = []
            for exclude in self.exclude_modules:
                if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
                    new_exclude_modules.append(exclude[:-1])
                    new_exclude_modules.append(exclude[:-1] + ".*")
                else:
                    new_exclude_modules.append(exclude)

            self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from already-validated fields."""
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse a ModelOpt quantization config and build the concrete config.

        Two layouts are accepted:

        * legacy ``hf_quant_config.json``:
          ``{"quantization": {"quant_algo": ..., "exclude_modules": [...]}}``
        * compressed-tensors style:
          ``{"quant_algo": ..., "quant_method": "modelopt", "ignore": [...]}``

        ``quant_algo`` and ``kv_cache_quant_algo`` are upper-cased, and a null
        exclude list is treated as empty.

        Raises:
            ValueError: if the config is malformed or ``quant_algo`` is not in
                ``QUANT_ALGOS``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        # None means no KV cache quantization; otherwise it must be a string.
        if kv_cache_quant_method is not None:
            if not isinstance(kv_cache_quant_method, str):
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_method)}"
                )
            kv_cache_quant_method = kv_cache_quant_method.upper()

        # An explicit JSON null means "nothing excluded", like a missing key.
        if exclude_modules is None:
            exclude_modules = []
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8.

    The linear method is selected from ``quant_method``:

    * ``FP8`` -> ``ModelOptFp8LinearMethod``
    * ``FP8_PER_CHANNEL_PER_TOKEN`` -> ``ModelOptFp8PcPtLinearMethod``
    * ``FP8_PB_WO`` -> ``ModelOptFp8PbWoLinearMethod``
    """

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """
        Args:
            quant_method: ModelOpt ``quant_algo`` (``from_config`` passes it
                upper-cased).
            is_checkpoint_fp8_serialized: Whether weights are stored in FP8.
            kv_cache_quant_method: KV-cache ``quant_algo`` or None.
            exclude_modules: ModelOpt exclude patterns.

        Raises:
            ValueError: if ``quant_method`` is not one of the FP8 algorithms.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo.
        if self.quant_method == "FP8":
            self.LinearMethodCls = ModelOptFp8LinearMethod
        elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
            self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
        elif self.quant_method == "FP8_PB_WO":
            self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
        else:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` is ``modelopt`` (any case)
        and the ``quant_algo`` (nested under ``quantization`` or top-level)
        contains ``FP8``; otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. ``or ""`` also covers an
        # explicit null value, which would otherwise fail on ``.lower()``.
        quant_method = str(hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        if "FP8" in str(quant_algo).upper():
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config; every FP8 algo implies FP8-serialized weights."""
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
433

434
435
436
437

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt static FP8 quantization (quant_algo ``FP8``).

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation scale; dynamic scales might be supported in the future.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm.
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations and weights both use static per-tensor FP8 keys.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8StaticTensorSym,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=type(self).__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, its scale parameters.

        ``weight`` has shape [sum(output_partition_sizes),
        input_size_per_partition] and is float8_e4m3fn for FP8-serialized
        checkpoints (``params_dtype`` otherwise). For FP8-serialized
        checkpoints, ``weight_scale`` and ``input_scale`` hold one float32
        value per logical output partition.
        """
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        out_features = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if fp8_serialized:
            weight_dtype = torch.float8_e4m3fn
        else:
            weight_dtype = params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    out_features, input_size_per_partition, dtype=weight_dtype
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # Weight scale first, then the input (activation) scale; both start at
        # the float32 minimum until checkpoint values are loaded.
        for param_name in ("weight_scale", "input_scale"):
            per_shard = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=loader,
            )
            per_shard[:] = torch.finfo(torch.float32).min
            layer.register_parameter(param_name, per_shard)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-partition scales into single per-tensor scales.

        If the partitions' weight scales differ, the weight is requantized to
        the largest scale with ``requantize_with_max_scale``. The weight is then
        stored transposed, and ``input_scale`` becomes its maximum.
        """
        shard_scales = layer.weight_scale
        if (shard_scales == shard_scales[0]).all():
            weight, max_w_scale = layer.weight, shard_scales.max()
        else:
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel on ``x`` with the layer's weights."""
        return self.fp8_linear.apply_weights(layer, x, bias)
522
523


524
525
526
527
528
529
530
531
532
533
534
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.

    Checkpoint layout expected for each Linear:
    - ``weight``: float8_e4m3fn, shape [out, in]
    - ``weight_scale``: float32, shape [out], one scale per output channel
    - no ``input_scale``; activations are quantized dynamically per token
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic per-token activation key, static weight key.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8StaticTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=type(self).__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an FP8 ``weight`` and a per-output-channel ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        loader = extra_weight_attrs.get("weight_loader")
        rows = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = rows

        fp8_weight = torch.empty(
            rows, input_size_per_partition, dtype=torch.float8_e4m3fn
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=fp8_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # One float32 scale per output row, pre-filled with the float32 minimum.
        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(rows, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Store the weight transposed and re-wrap both tensors as frozen Parameters."""
        transposed = layer.weight.t()
        scales = layer.weight_scale.data
        layer.weight = Parameter(transposed, requires_grad=False)
        layer.weight_scale = Parameter(scales, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel on ``x`` with the layer's weights."""
        return self.fp8_linear.apply_weights(layer, x, bias)
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims; this class assumes
    ``_WEIGHT_BLOCK_SIZE`` = (128, 128).

    vLLM executes it as an FP8 block GEMM with *dynamic* activation
    quantization in groups of shape (1, block_k), i.e. per token and per
    block_k input features.
    """

    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate block alignment, then register ``weight`` and ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized, or if the total
                output size, any logical output partition, or the input size
                is not a multiple of the corresponding block dimension.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        block_n, block_k = self._WEIGHT_BLOCK_SIZE

        # Validate before touching ``layer`` so a rejected shape leaves it
        # unmodified.
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )
        # Merged layers (several logical partitions) load each shard's scales
        # at an offset translated into block space, so every shard must be
        # block-aligned on its own, not just their sum.
        if any(size % block_n != 0 for size in output_partition_sizes):
            raise ValueError(
                "ModelOpt FP8_PB_WO requires every output partition to be "
                f"divisible by {block_n}, got {output_partition_sizes}."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Squeeze the 4D block scale to 2D [out_blk, in_blk] and freeze params.

        Raises:
            ValueError: if ``weight_scale`` is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-FP8 GEMM; activation scales are computed dynamically."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales
    and static per-tensor activation scales.

    Args:
        quant_config: The ModelOpt quantization config. Must describe an
            FP8-serialized checkpoint.
        moe_config: The fused-MoE layer configuration.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        # Only pre-quantized (serialized) FP8 checkpoints are handled here.
        assert self.quant_config.is_checkpoint_fp8_serialized

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype required by the kernel's prepare/finalize stage,
        or None while the kernel has not been built yet."""
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend.

        Returns None for FlashInfer TRTLLM, and for FlashInfer CUTLASS when
        all2all kernels are not in use. Otherwise it builds the FlashInfer
        CUTLASS FP8 prepare/finalize, or defers to the base class for every
        other backend.
        """
        # TRT LLM not supported with all2all yet.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=False,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the experts implementation for the modular-kernel path.

        Requires ``self.moe_quant_config`` to have been set beforehand.
        """
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None
        return make_fp8_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate FP8 expert weights plus per-tensor weight and input scales.

        The following parameters are registered on ``layer``:

        - ``w13_weight``: [E, shards * I, H]
        - ``w2_weight``: [E, H, I]
        - ``w13_weight_scale``: [E, shards]
        - ``w2_weight_scale``, ``w13_input_scale``, ``w2_input_scale``: [E]

        Here E = ``num_experts``, H = ``hidden_size``,
        I = ``intermediate_size_per_partition``, and ``shards`` is 2 for
        gated (act-and-mul) MoE and 1 otherwise. All scales are initialized
        to 1.0.
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, w13_num_shards),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: torch.nn.Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend's kernel format, store them on
        ``layer``, and build the modular kernel.

        ``self.kernel`` and ``self.use_inplace`` are set only when
        ``get_fused_moe_quant_config`` returns a truthy config.
        """
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse loaded scales to per-tensor form, then set up the kernel."""
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 MoE quant config from the layer's weight and
        activation scales."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    @property
    def is_monolithic(self) -> bool:
        """True when the FlashInfer TRTLLM backend handles routing + experts
        in a single call (see ``apply_monolithic``)."""
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FlashInfer TRTLLM per-tensor FP8 MoE directly from router
        logits.

        Raises:
            NotImplementedError: if EPLB is enabled on ``layer``.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
            )
        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )
        assert not layer.renormalize
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular FP8 MoE kernel on pre-computed top-k routing."""
        assert not self.is_monolithic

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.kernel is not None
        return self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )


# Attach the FP8 per-layer method classes (linear, fused-MoE, KV cache) to
# ModelOptFp8Config as class attributes. NOTE(review): presumably read by
# ModelOptQuantConfigBase when choosing a layer's quant method — confirm.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4 (NVFP4)."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """
        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint already
                stores NVFP4-quantized weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name, if any.
            exclude_modules: Module patterns forwarded to the base config.
            group_size: Number of elements sharing one block scale.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set unconditionally: previously these were only assigned for
        # serialized checkpoints, leaving the attributes missing otherwise.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``hf_quant_cfg`` declares
        ``quant_method == "modelopt"`` (case-insensitive) and its quant_algo
        names an NVFP4/FP4 algorithm. Returns None otherwise, including when
        a field is missing, null, or not a string. ``user_quant`` is not
        consulted.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. Guard against an explicit
        # null / non-string value rather than crashing on ``.lower()``.
        quant_method = hf_quant_cfg.get("quant_method")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed checkpoint fields.

        ``group_size`` defaults to 16 when None. For NVFP4-serialized
        checkpoints whose ``original_config`` has a ``"quantization"``
        section, that section must contain ``group_size``,
        ``kv_cache_quant_algo`` and ``exclude_modules``.

        Raises:
            ValueError: if any of those required fields are missing.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters, as
    allocated by ``create_weights`` (N = output size per partition,
    K = input size per partition):

        weight: NVFP4 values packed two per byte, uint8, shape [N, K // 2]
        weight_scale: FP8-E4M3 per-block scale, shape [N, K // group_size]
        weight_scale_2: float32 global weight scale, one entry per logical
            output partition (reduced to a scalar via max after loading)
        input_scale: float32 global activation scale, one entry per logical
            output partition (reduced to a scalar via max after loading)

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        # Choose the NVFP4 GEMM backend. When VLLM_NVFP4_GEMM_BACKEND is
        # unset, auto-detect in order flashinfer-cutlass -> cutlass -> marlin;
        # otherwise honor the requested backend and check it is available.
        requested_backend = envs.VLLM_NVFP4_GEMM_BACKEND
        self.backend = "none"
        if requested_backend is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif requested_backend.startswith("flashinfer-"):
            self.backend = requested_backend
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested_backend == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
        elif requested_backend == "marlin":
            self.backend = "marlin"
            assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once("Using %s for NVFP4 GEMM", self.backend)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 weight, block scale, and global scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )
        # The packed NVFP4 weight itself is stored as uint8 (below);
        # weight_dtype is only used for the per-block weight_scale.
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce global scales to scalars and lay out weights for the backend.

        Computes ``alpha = input_scale * weight_scale_2`` and
        ``input_scale_inv = 1 / input_scale``, then:

        - marlin: converts the layer via ``prepare_fp4_layer_for_marlin``.
        - flashinfer-trtllm: shuffles the weight and block scale.
        - otherwise: swizzles the block scale and pads the weight.
        """
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            # Swizzle block scales and pad the packed NVFP4 weights for kernel
            # alignment (CUTLASS/FlashInfer require K and N divisible by 32).
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)

            weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
                layer.weight.data
            )
            layer.weights_padding_cols = weights_padding_cols
            layer.weight = Parameter(weight, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the NVFP4 weight.

        The output keeps all leading dims of ``x`` and replaces the last one
        with the layer's output size per partition.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        output_dtype = x.dtype
        output_size = layer.output_size_per_partition
        # Preserve every leading dim of x; a hard-coded [x.shape[0], N]
        # breaks the final view for inputs with more than two dims.
        output_shape = [*x.shape[:-1], output_size]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(
            x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
        )

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        # Pad activations to match weight K-dimension padding
        weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
        x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        # Slice output to remove N-dimension padding
        out = slice_nvfp4_output(out, output_size)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Args:
        quant_config: NVFP4 Quant Config
        moe_config: Config of the fused MoE layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

        # When True, create_weights sizes the input scales by the global
        # expert count instead of the local one.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype requested by the kernel's prepare/finalize.

        Returns None until the modular kernel has been built in
        process_weights_after_loading.
        """
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Pick the prepare/finalize object for the selected NVFP4 backend.

        - FLASHINFER_TRTLLM: None.
        - FLASHINFER_CUTLASS: None without all2all kernels; otherwise the
          FlashInfer FP4 CUTLASS prepare/finalize.
        - Any other backend: defer to the base-class implementation.
        """
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the experts implementation for the modular-kernel path.

        Requires ``self.moe_quant_config`` to have been populated (it is set
        in process_weights_after_loading) and an experts class to have been
        selected in __init__.
        """
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None
        return make_nvfp4_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register empty NVFP4 expert parameters on ``layer``.

        Registers packed uint8 weights, float8_e4m3fn block scales, float32
        per-tensor weight scales (``*_weight_scale_2``) and float32 input
        scales for both GEMMs (w13 and w2).

        Args:
            layer: Module receiving the parameters.
            num_experts: Number of experts held by this rank.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Stored on ``layer.params_dtype``.
            **extra_weight_attrs: Reads ``weight_loader`` and
                ``global_num_experts``.

        Raises:
            ValueError: If the backend uses global scale factors but
                ``global_num_experts`` was not provided.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        # w13 fuses the gate and up projections only for gated activations.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # Validate up front, before any large weight tensor is allocated:
        # with global scale factors the input scales are sized by the global
        # expert count, which must therefore be supplied.
        if self.use_global_sf and global_num_experts is None:
            raise ValueError(
                "global_num_experts must be provided when the NVFP4 MoE "
                f"backend {self.nvfp4_backend} uses global scale factors."
            )
        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # one block scale per group_size input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one block scale per group_size input elements
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-tensor (second-level) weight scales, one per expert and shard.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # Activation input scales.
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.
        """
        # Use a single gscale for w13.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config and (
            (not self.moe.moe_parallel_config.use_all2all_kernels)
            or self.moe.moe_parallel_config.use_naive_all2all_kernels
        ):
            assert self.experts_cls is not None
            self.kernel = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
            )

    @property
    def do_post_quant_allgather(self):
        """True only for the FlashInfer TRTLLM backend."""
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes ``hidden_states`` with ``flashinfer.fp4_quantize`` using
        ``layer.a1_gscale`` and a non-swizzled scale-factor layout. Returns
        the quantized tensor together with ``[scale_factors]``.
        ``layer.a1_gscale`` is presumably set during weight conversion —
        TODO confirm.

        Raises:
            RuntimeError: If the backend is not FlashInfer TRTLLM.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        import flashinfer

        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's current scales."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        """Monolithic (router + experts in one call) only for TRTLLM w/o EPLB."""
        return (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not self.moe.moe_parallel_config.enable_eplb
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and experts together via the FlashInfer TRTLLM kernel."""
        assert self.is_monolithic
        assert (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        )
        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=layer.top_k,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            custom_routing_function=layer.custom_routing_function,
            e_score_correction_bias=layer.e_score_correction_bias,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts on pre-computed routing (non-monolithic path).

        The TRTLLM backend reaches this path only with EPLB enabled and uses
        the routed TRTLLM kernel. Every other backend uses the modular kernel
        built in process_weights_after_loading.
        """
        assert not self.is_monolithic

        # EPLB path
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
            )
        else:
            assert self.kernel is not None
            return self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )


# Attach the NVFP4 linear, fused-MoE and KV-cache method implementations to
# the config class (done here because the classes are defined after it).
for _attr_name, _method_cls in (
    ("LinearMethodCls", ModelOptNvFp4LinearMethod),
    ("FusedMoEMethodCls", ModelOptNvFp4FusedMoE),
    ("KVCacheMethodCls", ModelOptFp8KVCacheMethod),
):
    setattr(ModelOptNvFp4Config, _attr_name, _method_cls)
# Keep the module namespace free of the loop temporaries.
del _attr_name, _method_cls