"vscode:/vscode.git/clone" did not exist on "6cc1e7d96dab6b9c344ec87dec6dc9ab07ad5d21"
modelopt.py 61.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe import FusedMoERouter
17
from vllm.model_executor.layers.fused_moe.config import (
18
    FusedMoEConfig,
19
20
    FusedMoEQuantConfig,
)
21
from vllm.model_executor.layers.fused_moe.layer import (
22
23
24
25
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
26
27
28
29
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
30
    make_fp8_moe_kernel_for_mkm,
31
32
33
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
34
35
36
37
38
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
39
    make_nvfp4_moe_kernel_for_mkm,
40
41
42
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
43
44
45
46
47
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
48
from vllm.model_executor.layers.quantization import QuantizationMethods
49
from vllm.model_executor.layers.quantization.base_config import (
50
51
52
    QuantizationConfig,
    QuantizeMethodBase,
)
53
54
55
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
56
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
57
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
58
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
59
    flashinfer_trtllm_fp4_moe,
60
    flashinfer_trtllm_fp4_routed_moe,
61
)
62
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
63
    apply_fi_trtllm_fp8_per_tensor_moe,
64
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
65
)
66
67
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
68
69
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
70
)
71
72
73
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
74
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
75
76
77
78
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
)
79
from vllm.model_executor.layers.quantization.utils.quant_utils import (
80
81
82
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
83
84
85
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
86
87
    kNvfp4Dynamic,
    kNvfp4Static,
88
89
    swizzle_blockscale,
)
90
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
91
    cutlass_block_fp8_supported,
92
93
    requantize_with_max_scale,
)
94
95
96
97
98
99
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
100
from vllm.model_executor.utils import replace_parameter
101
102
103
104
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
)
105

106
107
108
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

109
110
logger = init_logger(__name__)

111
112
113
114
115
116
117
118
119
120
# ModelOpt `quant_algo` values accepted by this module.
# `ModelOptQuantConfigBase.from_config` upper-cases the checkpoint's value
# before testing membership in this list, so every entry must be upper-case.
QUANT_ALGOS = [
    # FP8 (per-tensor weight + optional static activation scale).
    "FP8",
    # FP8 per-channel weight scale + per-token activation scale.
    "FP8_PER_CHANNEL_PER_TOKEN",
    # FP8 per-block weight-only (ModelOpt may emit this as lowercase).
    "FP8_PB_WO",
    # FP4
    "NVFP4",
]
121
# KV-cache quantization algorithm names (upper-case) for ModelOpt checkpoints.
# NOTE(review): no consumer of this list is visible in this part of the file;
# confirm where it is checked before changing it.
KV_CACHE_QUANT_ALGOS = ["FP8"]
122
123


124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for ModelOpt FP8 checkpoints.

    Supports loading kv-cache scaling factors from FP8 checkpoints. All of
    the behaviour comes from ``BaseKVCacheMethod``; this subclass only
    forwards the ModelOpt quantization config to it.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        """Forward ``quant_config`` unchanged to the base KV-cache method."""
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base class of the ModelOpt quantization configs.

    It keeps the list of modules excluded from quantization, maps a layer to
    the quantization method that should handle it, and parses the ModelOpt
    configuration dictionary. Concrete configs implement ``_from_config``.
    """

    # Method classes instantiated by `get_quant_method`. The values here are
    # the generic base classes and act as overridable defaults.
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        """Remember which modules must stay unquantized.

        Args:
            exclude_modules: Module names or wildcard patterns to exclude.
        """
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Tell whether the layer named ``prefix`` is excluded from quantization.

        Three checks run in order: exact matching (with fused-layer support),
        legacy substring matching, and wildcard matching, because the
        ModelOpt exclude_modules list is a list of wildcards.
        """
        patterns = self.exclude_modules
        if not patterns:
            return False

        # 1) Exact matching, aware of fused/packed layers.
        if is_layer_skipped(prefix, patterns, self.packed_modules_mapping):
            return True

        # 2) Substring matching for patterns the exact match did not catch.
        # TODO: This special hard coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where such layers are
        # handled naturally by the exclude_modules config. It is kept for
        # loading quantized checkpoints generated by older versions.
        has_lm_prefix = prefix.startswith("language_model.")
        without_lm_prefix = prefix.removeprefix("language_model.")
        substring_hit = any(
            # Patterns equal to the prefix were already handled by step 1.
            pattern != prefix
            and (pattern in prefix or (has_lm_prefix and pattern in without_lm_prefix))
            for pattern in patterns
        )
        if substring_hit:
            return True

        # 3) ModelOpt exclude modules are wildcards, not simple strings.
        return any(fnmatch(prefix, pattern) for pattern in patterns)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module being set up.
            prefix: The module's name, used for the exclusion checks.

        Returns:
            A KV-cache method for ``Attention`` layers, an unquantized linear
            method for excluded linear layers and for vision modules, the
            configured linear / fused-MoE method for quantized layers, and
            ``None`` for everything else.
        """
        # KV-cache goes first so that everything below only deals with
        # weight quantization.
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # Excluded layers: linear layers run unquantized, others get nothing.
        if self.is_layer_excluded(prefix):
            return UnquantizedLinearMethod() if isinstance(layer, LinearBase) else None

        # TODO: This special hard coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where vision modules
        # are handled naturally by the exclude_modules config. It is kept for
        # loading quantized checkpoints generated by older versions.
        # NOTE(review): this returns an unquantized *linear* method without
        # checking the layer type - confirm that is intended.
        if any(tag in prefix for tag in ("vision_tower", "vision_model")):
            return UnquantizedLinearMethod()

        # From here on the layer is quantized.
        if isinstance(layer, LinearBase):
            method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
        else:
            return None

        # Methods reporting a "marlin" backend additionally get an input dtype.
        if getattr(method, "backend", "") == "marlin":
            method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` through ``hf_to_vllm_mapper``.

        Nothing happens when the exclusion list is empty.
        """
        if not self.exclude_modules:
            return

        # This is a workaround for the weights remapping issue:
        # https://github.com/vllm-project/vllm/issues/28072
        # Right now, the Nvidia ModelOpt library uses just one wildcard
        # pattern:
        #        module_path*
        # It gets applied if the whole tree of modules rooted at module_path
        # is not quantized. Such a pattern is replaced here by 2 patterns
        # that are collectively equivalent to the original one:
        #        module_path
        #        module_path.*
        expanded: list[str] = []
        for pattern in self.exclude_modules:
            stem = pattern[:-1]
            is_bare_trailing_star = (
                len(pattern) >= 2 and pattern[-1] == "*" and pattern[-2] != "."
            )
            if is_bare_trailing_star:
                expanded.extend((stem, stem + ".*"))
            else:
                expanded.append(pattern)

        self.exclude_modules = hf_to_vllm_mapper.apply_list(expanded)

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the file names searched for the quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the config from validated values; sub classes implement it."""
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Validate ``config`` and build the config object via ``_from_config``.

        Two layouts are understood:

        * Traditional ModelOpt format: ``{"quantization": {"quant_algo": "..."}}``
          with the exclusion list under ``"exclude_modules"``.
        * Compressed-tensors style format:
          ``{"quant_algo": "...", "quant_method": "modelopt"}`` with the
          exclusion list under ``"ignore"``.

        Raises:
            ValueError: If ``quant_algo`` is missing or unsupported, or if any
                of the other fields has an unexpected type.
        """
        if "quantization" in config:
            section = cls.get_from_keys(config, ["quantization"])
            if not isinstance(section, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_key = "exclude_modules"
        else:
            section = config
            # "ignore" is the key in config.json
            exclude_key = "ignore"

        algo = section.get("quant_algo")
        kv_cache_algo = section.get("kv_cache_quant_algo")
        raw_group_size = section.get("group_size")
        excluded = section.get(exclude_key, [])

        if not algo:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        algo = str(algo).upper()

        # `None` means there is no KV cache quantization and is kept as is.
        if kv_cache_algo is not None:
            if not isinstance(kv_cache_algo, str):
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_algo)}"
                )
            kv_cache_algo = kv_cache_algo.upper()

        if not isinstance(excluded, list):
            raise ValueError(f"exclude_modules must be a list, got {type(excluded)}")

        # `None` and ints pass through untouched; anything else must convert.
        if raw_group_size is None or isinstance(raw_group_size, int):
            group_size = raw_group_size
        else:
            try:
                group_size = int(raw_group_size)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(raw_group_size)}"
                ) from None

        if algo not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )

        return cls._from_config(
            quant_method=algo,
            kv_cache_quant_method=kv_cache_algo,
            exclude_modules=excluded,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the settings and select the linear method for the algorithm.

        Args:
            quant_method: The ModelOpt ``quant_algo`` value.
            is_checkpoint_fp8_serialized: Whether the checkpoint weights are
                stored in FP8.
            kv_cache_quant_method: KV-cache quantization algorithm, if any.
            exclude_modules: Module patterns excluded from quantization.

        Raises:
            ValueError: If ``quant_method`` is not a supported FP8 algorithm.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method

        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo.
        linear_method_by_algo = {
            "FP8": ModelOptFp8LinearMethod,
            "FP8_PER_CHANNEL_PER_TOKEN": ModelOptFp8PcPtLinearMethod,
            "FP8_PB_WO": ModelOptFp8PbWoLinearMethod,
        }
        linear_method_cls = linear_method_by_algo.get(self.quant_method)
        if linear_method_cls is None:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )
        self.LinearMethodCls = linear_method_cls

    def get_name(self) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this config reports as supported."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this config."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``hf_quant_cfg`` names the "modelopt"
        method and its ``quant_algo`` contains "FP8" (case-insensitive),
        otherwise ``None``. ``user_quant`` is not consulted.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; only proceed if the
        # method is explicitly "modelopt".
        if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific config structure.
            section = hf_quant_cfg["quantization"]
            if not isinstance(section, dict):
                return None
        else:
            # Compressed-tensors style config with a top-level quant_algo.
            section = hf_quant_cfg

        algo = str(section.get("quant_algo", "")).upper()
        return "modelopt" if "FP8" in algo else None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config from the values validated by ``from_config``.

        The checkpoint counts as FP8-serialized when ``quant_method``
        contains "FP8". ``original_config`` and extra keyword arguments are
        accepted but not used.
        """
        fp8_serialized = "FP8" in quant_method
        return cls(
            quant_method,
            fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
431

432
433
434
435

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static per-tensor quantization keys for both activations and weights.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8StaticTensorSym,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=type(self).__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, its scales on ``layer``.

        ``input_size`` and ``output_size`` are not used.
        """
        del input_size, output_size

        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        loader = extra_weight_attrs.get("weight_loader")
        total_output_size = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size

        # FP8-serialized checkpoints store float8_e4m3fn weights; otherwise
        # the weight uses the regular parameter dtype.
        storage_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype
        weight_data = torch.empty(
            total_output_size, input_size_per_partition, dtype=storage_dtype
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_data,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # WEIGHT SCALE first, then INPUT SCALE: one float32 value per entry
        # of `output_partition_sizes`, pre-filled with the float32 minimum
        # (presumably a placeholder for values not yet loaded - TODO confirm).
        for scale_name in ("weight_scale", "input_scale"):
            per_tensor_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=loader,
            )
            per_tensor_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter(scale_name, per_tensor_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the scales to single values and transpose the weight.

        NOTE(review): `weight_scale` and `input_scale` are only registered by
        `create_weights` for FP8-serialized checkpoints, yet they are read
        here unconditionally - confirm other checkpoints never get here.
        """
        loaded_scales = layer.weight_scale
        if (loaded_scales != loaded_scales[0]).any():
            # The scales differ, so requantize to one shared maximum scale.
            final_scale, final_weight = requantize_with_max_scale(
                layer.weight, loaded_scales, layer.logical_widths
            )
        else:
            final_scale, final_weight = loaded_scales.max(), layer.weight

        layer.weight = Parameter(final_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(final_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer on ``x`` through the FP8 linear kernel."""
        return self.fp8_linear.apply_weights(layer, x, bias)
520
521


522
523
524
525
526
527
528
529
530
531
532
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints.

    Expected checkpoint structure (per Linear):
    - weight: fp8-e4m3fn, shape [out, in]
    - weight_scale: fp32, shape [out] (per-output-channel)
    - no input_scale (activations are dynamically quantized per-token)
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic per-token activation key, static per-token weight key.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8StaticTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=type(self).__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and its per-channel scale on ``layer``.

        ``input_size``, ``output_size`` and ``params_dtype`` are not used.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        loader = extra_weight_attrs.get("weight_loader")
        total_output_size = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size

        weight_data = torch.empty(
            total_output_size,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_data,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # One float32 scale per output row, pre-filled with the float32
        # minimum (presumably a placeholder until loading - TODO confirm).
        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(total_output_size, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Swap the loaded parameters for plain ``Parameter`` objects.

        The weight is transposed; the scale data is kept as loaded.
        """
        transposed_weight = layer.weight.t()
        scale_data = layer.weight_scale.data
        layer.weight = Parameter(transposed_weight, requires_grad=False)
        layer.weight_scale = Parameter(scale_data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer on ``x`` through the FP8 linear kernel."""
        return self.fp8_linear.apply_weights(layer, x, bias)
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    where block size is typically 128 for both dims.

    vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant.
    """

    # (block_n, block_k): block extent along the output / input dimension.
    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and its block-wise scale on ``layer``.

        ``input_size``, ``output_size`` and ``params_dtype`` are not used.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized, or if the
                partitioned output / input size is not divisible by the
                block size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)

        # Validate the shapes before touching `layer`, so that an unsupported
        # shape neither allocates the weight tensor nor leaves the layer with
        # only part of its attributes and parameters set.
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        # Pre-fill with the float32 minimum (presumably a placeholder until
        # the real scales are loaded - TODO confirm).
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Re-wrap the weight and flatten the scale to 2D.

        Raises:
            ValueError: If the loaded scale is neither 4D nor 2D.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-wise FP8 linear op on ``x``.

        No input scale is passed to the op (``input_scale=None``).
        """
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


716
717
718
719
720
721
722
723
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.
    Args:
        quant_config: The ModelOpt quantization config.
    """

724
725
726
    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        """Bind the quantization config and choose the FP8 MoE backend.

        Args:
            quant_config: ModelOpt FP8 config; asserted to describe an
                FP8-serialized checkpoint.
            moe_config: Fused-MoE configuration forwarded to the base class.
        """
        super().__init__(moe_config)
        self.quant_config = quant_config
        assert quant_config.is_checkpoint_fp8_serialized

        # Weights and activations are selected with the same quantization key.
        backend, experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )
        self.fp8_backend = backend
        self.experts_cls = experts_cls

        # The modular kernel is only built once weights have been processed.
        self.kernel: mk.FusedMoEModularKernel | None = None
742

743
744
745
746
747
748
    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype reported by the kernel, or None if not built."""
        kernel = self.kernel
        if kernel is None:
            return None
        return kernel.prepare_finalize.topk_indices_dtype()

749
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend, if any.

        Args:
            routing_tables: Forwarded to the base implementation when the
                backend is not one of the FlashInfer cases handled here.

        Returns:
            None for the FlashInfer TRTLLM backend and for FlashInfer CUTLASS
            without all2all kernels; a FlashInfer CUTLASS prepare/finalize
            object when all2all kernels are in use; otherwise the base class
            result.
        """
        backend = self.fp8_backend

        # TRT LLM not supported with all2all yet.
        if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        if backend != Fp8MoeBackend.FLASHINFER_CUTLASS:
            return super().maybe_make_prepare_finalize(routing_tables)

        # For no-EP case, don't use the MKM framework.
        if not self.moe.moe_parallel_config.use_all2all_kernels:
            return None

        fi_prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            self.moe,
            use_deepseek_fp8_block_scale=False,
        )
        logger.debug_once("%s", fi_prepare_finalize.__class__.__name__)
        return fi_prepare_finalize
768
769
770
771

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the experts implementation paired with ``prepare_finalize``.

        Both the MoE quant config and the experts class must already be set;
        ``layer`` is accepted for interface compatibility and is not used.
        """
        quant_config = self.moe_quant_config
        assert quant_config is not None
        experts_cls = self.experts_cls
        assert experts_cls is not None

        return make_fp8_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=quant_config,
            experts_cls=experts_cls,
            prepare_finalize=prepare_finalize,
        )
782
783
784
785
786
787
788
789
790
791

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the expert weights and per-tensor scales.

        Registers on ``layer``, in this order: ``w13_weight``, ``w2_weight``,
        ``w13_weight_scale``, ``w2_weight_scale``, ``w13_input_scale`` and
        ``w2_input_scale``. Weight storage is uninitialized; every scale
        starts at 1.0. Also records ``orig_dtype`` and ``num_experts`` on the
        layer.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Dtype stored as ``layer.orig_dtype``; also the weight
                dtype when the checkpoint is not FP8-serialized.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        loader = extra_weight_attrs.get("weight_loader")
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight_dtype = torch.float8_e4m3fn
        else:
            weight_dtype = params_dtype

        # Gated MoE stacks w1 and w3 in one tensor; non-gated has only one.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        def _expert_weight(out_features, in_features):
            return ModelWeightParameter(
                data=torch.empty(
                    num_experts,
                    out_features,
                    in_features,
                    dtype=weight_dtype,
                ),
                input_dim=2,
                output_dim=1,
                weight_loader=loader,
            )

        def _unit_scale(*shape):
            return PerTensorScaleParameter(
                data=torch.full(shape, 1.0, dtype=torch.float32),
                weight_loader=loader,
            )

        layer.register_parameter(
            "w13_weight",
            _expert_weight(
                w13_num_shards * intermediate_size_per_partition, hidden_size
            ),
        )
        layer.register_parameter(
            "w2_weight",
            _expert_weight(hidden_size, intermediate_size_per_partition),
        )

        # WEIGHT SCALES - Per-tensor scaling for ModelOpts
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        # INPUT SCALES - Per-tensor scaling for ModelOpt
        scale_params = (
            ("w13_weight_scale", _unit_scale(num_experts, w13_num_shards)),
            ("w2_weight_scale", _unit_scale(num_experts)),
            ("w13_input_scale", _unit_scale(num_experts)),
            ("w2_input_scale", _unit_scale(num_experts)),
        )
        for param_name, param in scale_params:
            layer.register_parameter(param_name, param)
861

862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
    def _setup_kernel(
        self,
        layer: torch.nn.Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend format and build the MoE kernel.

        The converted weights and weight scales replace the corresponding
        parameters on ``layer``. The quant config is then rebuilt from the
        layer and, when it is truthy, ``self.kernel`` and ``self.use_inplace``
        are set.
        """
        converted = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )
        w13, w2, w13_scale, w2_scale = converted

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        updated = (
            ("w13_weight", w13),
            ("w2_weight", w2),
            ("w13_weight_scale", w13_scale),
            ("w2_weight_scale", w2_scale),
        )
        for param_name, tensor in updated:
            replace_parameter(layer, param_name, tensor)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if not self.moe_quant_config:
            return

        assert self.experts_cls is not None
        kernel, use_inplace = make_fp8_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            fp8_backend=self.fp8_backend,
            experts_cls=self.experts_cls,
        )
        self.kernel = kernel
        self.use_inplace = use_inplace
900

901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded scales and weights, then set up the kernel.

        The input scales are processed and written back to ``layer`` first,
        then the w13 weight and its scale are processed, and finally
        everything is handed to ``_setup_kernel``.
        """
        expert_w13 = layer.w13_weight
        expert_w2 = layer.w2_weight
        expert_w13_scale = layer.w13_weight_scale
        expert_w2_scale = layer.w2_weight_scale

        # Per tensor kernels require single activation scale. Use the max.
        act13_scale, act2_scale = process_fp8_input_tensor_strategy_moe(
            layer.w13_input_scale, layer.w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", act13_scale)
        replace_parameter(layer, "w2_input_scale", act2_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        expert_w13, expert_w13_scale = process_fp8_weight_tensor_strategy_moe(
            expert_w13,
            expert_w13_scale,
            layer.intermediate_size_per_partition,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            expert_w13,
            expert_w2,
            expert_w13_scale,
            expert_w2_scale,
            act13_scale,
            act2_scale,
        )
931

932
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's current scales.

        Reads the w13/w2 weight scales and input scales from ``layer`` and
        passes them, with the selected backend, to
        ``make_fp8_moe_quant_config``.
        """
        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
947

948
949
    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 fused-MoE forward pass for ``x``.

        The FlashInfer TRTLLM backend is dispatched to
        ``apply_fi_trtllm_fp8_per_tensor_moe`` with the raw router logits.
        Every other backend selects experts through ``router`` and runs
        ``self.kernel``.

        Args:
            layer: Fused-MoE layer providing weights and routing settings.
            router: Router used for expert selection on the non-TRTLLM path.
            x: Input hidden states.
            router_logits: Router logits for ``x``.

        Returns:
            The output of the selected kernel.

        Raises:
            NotImplementedError: If EPLB is enabled with the FlashInfer
                TRTLLM backend.
        """
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
                )
            # TODO(rob): this validation should happen at kernel selection
            # time in the oracle rather than here.
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )
            assert not layer.renormalize
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # The two literals are concatenated, so the first must end with a
            # space to keep the message readable.
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1007
1008


1009
1010
1011
1012
1013
1014
# Attach the concrete ModelOpt FP8 method classes to the config class as
# class attributes.
# NOTE(review): presumably assigned here rather than in the class body because
# these method classes are defined after ModelOptFp8Config — confirm.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
1015
1016
1017
1018
1019
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the NVFP4 checkpoint settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint is already
                serialized in NVFP4.
            kv_cache_quant_algo: KV-cache quantization algorithm, if any.
            exclude_modules: Module patterns forwarded to the base class.
            group_size: Quantization group size.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Set unconditionally so the attributes exist on every instance;
        # previously they were only assigned for serialized checkpoints and
        # reading them otherwise raised AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

1035
    def get_name(self) -> QuantizationMethods:
        """Return the name of this quantization method."""
        method_name: QuantizationMethods = "modelopt_fp4"
        return method_name
1037

1038
    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        supported = [torch.bfloat16, torch.float16, torch.float8_e4m3fn]
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method.

        NOTE(review): 75 presumably encodes compute capability 7.5 — confirm.
        """
        min_capability = 75
        return min_capability
1044

1045
1046
    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Args:
            hf_quant_cfg: Quantization config mapping from the checkpoint, or
                None.
            user_quant: Quantization method requested by the user; not
                consulted here.

        Returns:
            ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"`` and the
            config names an FP4 algorithm, otherwise None. Missing, null or
            non-string fields yield None instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. A JSON ``null`` or other
        # non-string value simply means "not ours".
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                # Guard the type as the flat branch below does; a null
                # ``quant_algo`` would otherwise raise TypeError on ``in``.
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

1077
    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from already-extracted quantization fields.

        Args:
            quant_method: Quantization algorithm name; the checkpoint counts
                as NVFP4-serialized when it contains ``"NVFP4"``.
            kv_cache_quant_method: KV-cache quantization algorithm, if any.
            exclude_modules: Module patterns to exclude.
            original_config: The raw config; its ``"quantization"`` section,
                when present, is checked for required fields.
            group_size: Quantization group size; None falls back to 16.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If an NVFP4 ``"quantization"`` section lacks any of
                the required fields.
        """
        serialized = "NVFP4" in quant_method
        group_size = 16 if group_size is None else group_size  # Default value

        # For FP4, these fields are required
        if serialized and "quantization" in original_config:
            section = original_config["quantization"]
            required = ("group_size", "kv_cache_quant_algo", "exclude_modules")
            missing_fields = [name for name in required if name not in section]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )
1113
1114
1115
1116
1117


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:
1118

1119
1120
1121
1122
1123
1124
1125
    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,
    Args: quant_config: The ModelOpt quantization config.
    """

1126
    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        """Store the config and choose the NVFP4 GEMM backend.

        With ``VLLM_NVFP4_GEMM_BACKEND`` unset, the first available of
        FlashInfer (CUTLASS), CUTLASS and Marlin is used. Otherwise the
        requested backend is taken and asserted to be available.

        Raises:
            ValueError: If no backend was selected.
        """
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        requested = envs.VLLM_NVFP4_GEMM_BACKEND
        explicit = {
            "cutlass": (cutlass_fp4_supported, "Cutlass"),
            "marlin": (is_fp4_marlin_supported, "Marlin"),
        }

        self.backend = "none"
        if requested is None:
            # Auto-detect: probes run lazily, in preference order.
            probes = (
                ("flashinfer-cutlass", has_flashinfer),
                ("cutlass", cutlass_fp4_supported),
                ("marlin", is_fp4_marlin_supported),
            )
            self.backend = next(
                (name for name, available in probes if available()), "none"
            )
        elif requested.startswith("flashinfer-"):
            self.backend = requested
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested in explicit:
            is_supported, label = explicit[requested]
            self.backend = requested
            assert is_supported(), f"{label} is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

1156
1157
1158
1159
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 weight and its scales on ``layer``.

        Registers, in order: ``weight`` (uint8, two FP4 values packed per
        byte along the input dim), ``input_scale`` and ``weight_scale_2``
        (one float32 entry per output partition) and ``weight_scale`` (one
        entry per ``group_size`` input elements).

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        quant_config = self.quant_config
        if not quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        loader = extra_weight_attrs.get("weight_loader")
        out_per_partition = sum(output_partition_sizes)
        num_partitions = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Block scales use FP8-E4M3 for serialized checkpoints.
        if quant_config.is_checkpoint_nvfp4_serialized:
            scale_dtype = torch.float8_e4m3fn
        else:
            scale_dtype = params_dtype

        def _matrix(cols, dtype):
            return ModelWeightParameter(
                data=torch.empty(out_per_partition, cols, dtype=dtype),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            )

        def _per_partition_scalar():
            return PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=loader,
            )

        # Weight: 2 fp4 items are packed in the input dimension
        layer.register_parameter(
            "weight", _matrix(input_size_per_partition // 2, torch.uint8)
        )

        # Input Weight Scale
        layer.register_parameter("input_scale", _per_partition_scalar())

        # Global Weight Scale
        layer.register_parameter("weight_scale_2", _per_partition_scalar())

        # Per Block Weight Scale
        layer.register_parameter(
            "weight_scale",
            _matrix(
                input_size_per_partition // self.quant_config.group_size,
                scale_dtype,
            ),
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce the global scales and lay out the weights for the backend.

        ``input_scale`` and ``weight_scale_2`` are each replaced by their
        maximum, and ``alpha`` (their product) and ``input_scale_inv`` are
        derived from them. The weight and block scale are then prepared per
        backend: Marlin repacking (which drops ``alpha`` and ``input_scale``),
        FlashInfer TRTLLM shuffling, or block-scale swizzling otherwise.
        """
        fp32 = torch.float32

        # global scales:
        layer.input_scale = Parameter(
            layer.input_scale.max().to(fp32), requires_grad=False
        )
        layer.weight_scale_2 = Parameter(
            layer.weight_scale_2.max().to(fp32), requires_grad=False
        )
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(fp32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        backend = self.backend
        if backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
            return

        if backend != "flashinfer-trtllm":
            swizzled = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)
            return

        # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
        # FlashInfer provides nvfp4_quantize to quantize + shuffle the
        # layout but we use our own quantization so we have to call
        # shuffles ourselves.
        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        epilogue_tile_m = 128
        raw_weight = layer.weight.data
        raw_scale = layer.weight_scale.data

        shuffled_weight = shuffle_matrix_a(
            raw_weight.view(torch.uint8), epilogue_tile_m
        )
        shuffled_scale = shuffle_matrix_sf_a(
            raw_scale.view(torch.uint8), epilogue_tile_m
        )
        # Restore the original scale shape and FP8 view after shuffling.
        shuffled_scale = shuffled_scale.reshape(raw_scale.shape).view(
            torch.float8_e4m3fn
        )

        layer.weight_scale = Parameter(shuffled_scale, requires_grad=False)
        layer.weight = Parameter(shuffled_weight, requires_grad=False)
1282
1283
1284
1285
1286

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the NVFP4 linear output for ``x``.

        The Marlin backend is handled entirely by ``apply_fp4_marlin_linear``.
        Otherwise ``x`` is quantized to FP4 with block scales and multiplied
        with the layer weight via the FlashInfer or CUTLASS scaled-FP4 GEMM.
        ``bias`` is added afterwards when given.

        Returns:
            The output viewed as ``[x.shape[0], layer.weight.shape[0]]`` on the
            non-Marlin path.
        """
        backend = self.backend
        if backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        out_dtype = x.dtype
        out_rows = x.shape[0]
        out_cols = layer.weight.shape[0]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        if self.backend.startswith("flashinfer-"):
            out = flashinfer_scaled_fp4_mm(
                x_fp4,
                layer.weight,
                x_blockscale,
                layer.weight_scale,
                layer.alpha,
                out_dtype,
                backend=self.backend.removeprefix("flashinfer-"),
            )
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(
                x_fp4,
                layer.weight,
                x_blockscale,
                layer.weight_scale,
                layer.alpha,
                out_dtype,
            )

        if bias is not None:
            out = out + bias
        return out.view(out_rows, out_cols)
1334
1335
1336
1337
1338


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1339
    Args:
1340
1341
1342
        quant_config: NVFP4 Quant Config
    """

1343
1344
1345
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        """Store the NVFP4 quant config and pick the experts backend.

        Args:
            quant_config: NVFP4 quantization config, kept on the instance.
            moe_config: Fused MoE config, forwarded to the base initializer.
        """
        super().__init__(moe_config)
        self.quant_config = quant_config

        # Select experts implementation: the selector returns the backend
        # together with the experts class that goes with it.
        # `self.moe` is presumably populated by the base initializer.
        backend, experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )
        self.nvfp4_backend = backend
        self.experts_cls = experts_cls

        # The kernel is left unset here; process_weights_after_loading may
        # create it once the weights have been converted.
        self.kernel: mk.FusedMoEModularKernel | None = None

        # Read by create_weights when sizing the input-scale tensors.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )
1363
1364
1365
1366
1367
1368

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None
1369

1370
1371
1372
1373
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend, if any.

        Args:
            routing_tables: Optional tensor triple; only forwarded to the
                base-class implementation.

        Returns:
            ``None`` for the FlashInfer TRTLLM backend and for FlashInfer
            CUTLASS when all2all kernels are not in use; the FlashInfer FP4
            CUTLASS prepare/finalize object when they are; for every other
            backend, whatever the base class returns.
        """
        backend = self.nvfp4_backend

        if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            return None

        if backend != NvFp4MoeBackend.FLASHINFER_CUTLASS:
            return super().maybe_make_prepare_finalize(routing_tables)

        # For no-EP case, don't use the MKM framework.
        if not self.moe.moe_parallel_config.use_all2all_kernels:
            return None

        # For now, fp4 moe only works with the flashinfer dispatcher.
        flashinfer_pf = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
        logger.debug_once("%s", flashinfer_pf.__class__.__name__)
        return flashinfer_pf
1388

1389
1390
1391
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the experts implementation for the modular-kernel path.

        Args:
            prepare_finalize: Prepare/finalize object handed to the factory.
            layer: Not read by this implementation.

        Returns:
            The result of ``make_nvfp4_moe_kernel_for_mkm``.

        ``self.moe_quant_config`` and ``self.experts_cls`` must both be set
        before this is called; each is checked with an assertion.
        """
        quant_cfg = self.moe_quant_config
        assert quant_cfg is not None
        experts_cls = self.experts_cls
        assert experts_cls is not None

        return make_nvfp4_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=quant_cfg,
            experts_cls=experts_cls,
            prepare_finalize=prepare_finalize,
        )
1402

1403
1404
1405
1406
1407
1408
    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.

        Returns:
            Always ``True`` for this method class.
        """
        return True

1409
1410
1411
1412
1413
1414
1415
1416
1417
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 MoE parameters on ``layer``.

        Parameters are registered in this order: ``w13_weight``,
        ``w2_weight``, ``w13_weight_scale``, ``w2_weight_scale``,
        ``w13_weight_scale_2``, ``w2_weight_scale_2``, ``w13_input_scale``,
        ``w2_input_scale``. Every tensor is created with ``torch.empty`` and
        is therefore uninitialized.

        Args:
            layer: Module receiving the parameters; ``num_experts``,
                ``params_dtype`` and ``quant_config`` are also set on it.
            num_experts: Leading dimension of the weights, the weight scales
                and (unless global scale factors are used) the input scales.
            hidden_size: Hidden size; halved for the ``w13`` weight and
                divided by the quant group size for its block scales.
            intermediate_size_per_partition: Intermediate size of this
                partition; halved / divided likewise for ``w2``.
            params_dtype: Recorded on ``layer`` only; the tensors created
                here use fixed dtypes (uint8, float8_e4m3fn, float32).
            **extra_weight_attrs: ``weight_loader`` and
                ``global_num_experts`` are looked up in here.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        loader = extra_weight_attrs.get("weight_loader")
        global_experts = extra_weight_attrs.get("global_num_experts")
        # w13 holds two shards when `is_act_and_mul` is set, one otherwise.
        shards = 2 if self.moe.is_act_and_mul else 1

        def _expert_matrix(rows: int, cols: int, dtype: torch.dtype):
            # Per-expert (num_experts, rows, cols) tensor as a weight param.
            return ModelWeightParameter(
                data=torch.empty(num_experts, rows, cols, dtype=dtype),
                input_dim=1,
                output_dim=2,
                weight_loader=loader,
            )

        def _per_tensor_scale(*shape):
            # float32 tensor of the given shape as a per-tensor scale param.
            return PerTensorScaleParameter(
                data=torch.empty(*shape, dtype=torch.float32),
                weight_loader=loader,
            )

        # GEMM 1 weight. 2 fp4 items are packed in the input dimension,
        # hence hidden_size // 2 uint8 columns.
        layer.register_parameter(
            "w13_weight",
            _expert_matrix(
                shards * intermediate_size_per_partition,
                hidden_size // 2,
                torch.uint8,
            ),
        )

        # GEMM 2 weight, packed the same way along its last dimension.
        layer.register_parameter(
            "w2_weight",
            _expert_matrix(
                hidden_size,
                intermediate_size_per_partition // 2,
                torch.uint8,
            ),
        )

        # Block scales: the last dimension is divided by the quant group
        # size (not by the fp4 packing factor used for the weights above).
        group_size = self.quant_config.group_size
        layer.register_parameter(
            "w13_weight_scale",
            _expert_matrix(
                shards * intermediate_size_per_partition,
                hidden_size // group_size,
                torch.float8_e4m3fn,
            ),
        )
        layer.register_parameter(
            "w2_weight_scale",
            _expert_matrix(
                hidden_size,
                intermediate_size_per_partition // group_size,
                torch.float8_e4m3fn,
            ),
        )

        # NOTE(review): nothing later in this method reads
        # extra_weight_attrs, so this write (and the TENSOR one below) has no
        # visible effect here - confirm whether it can be dropped.
        extra_weight_attrs["quant_method"] = FusedMoeWeightScaleSupported.BLOCK.value

        # Second-level weight scales: one per expert and w13 shard, and one
        # per expert for w2.
        layer.register_parameter(
            "w13_weight_scale_2", _per_tensor_scale(num_experts, shards)
        )
        layer.register_parameter("w2_weight_scale_2", _per_tensor_scale(num_experts))

        extra_weight_attrs["quant_method"] = FusedMoeWeightScaleSupported.TENSOR.value

        # Input scales are sized by the global expert count when the backend
        # supports global scale factors, otherwise by `num_experts`.
        sf_experts = global_experts if self.use_global_sf else num_experts
        layer.register_parameter(
            "w13_input_scale", _per_tensor_scale(sf_experts, shards)
        )
        layer.register_parameter("w2_input_scale", _per_tensor_scale(sf_experts))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.

        Steps: pick a single w13 second-level scale per expert, run the
        backend-specific conversion, swap the converted tensors onto
        ``layer``, build ``self.moe_quant_config`` and, when all2all kernels
        are not used or are the naive ones, create ``self.kernel``.

        Args:
            layer: MoE layer whose parameters are converted and replaced.
        """
        # Use a single gscale for w13: column 0 is kept. When the layer is
        # act-and-mul and column 1 differs, only a warning is logged.
        gscales = layer.w13_weight_scale_2
        if self.moe.is_act_and_mul:
            if not torch.allclose(gscales[:, 0], gscales[:, 1]):
                logger.warning_once(
                    "w1_weight_scale_2 must match w3_weight_scale_2. "
                    "Accuracy may be affected."
                )
        w13_gscale = gscales[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_gscale,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Swap every converted tensor onto the layer under its original name.
        converted = (
            ("w13_weight", w13),
            ("w13_weight_scale", w13_scale),
            ("w13_weight_scale_2", w13_scale_2),
            ("w13_input_scale", a13_scale),
            ("w2_weight", w2),
            ("w2_weight_scale", w2_scale),
            ("w2_weight_scale_2", w2_scale_2),
            ("w2_input_scale", a2_scale),
        )
        for param_name, tensor in converted:
            replace_parameter(layer, param_name, tensor)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if not self.moe_quant_config:
            return

        parallel_cfg = self.moe.moe_parallel_config
        if (
            parallel_cfg.use_all2all_kernels
            and not parallel_cfg.use_naive_all2all_kernels
        ):
            return

        assert self.experts_cls is not None
        self.kernel = make_nvfp4_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            experts_cls=self.experts_cls,
        )
1587

1588
1589
1590
1591
    @property
    def do_post_quant_allgather(self) -> bool:
        """Whether the selected backend is FlashInfer TRTLLM.

        NOTE(review): presumably tells the caller to quantize before the DP
        allgather (see prepare_dp_allgather_tensor) - confirm at the call site.
        """
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

1592
1593
1594
1595
1596
1597
1598
    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Quantize hidden states to FP4 ahead of the DP allgather/EP step.

        Args:
            layer: MoE layer; its ``a1_gscale`` is passed to the quantizer.
            hidden_states: Activations to quantize.
            router_logits: Not read by this implementation.

        Returns:
            The quantized hidden states, plus a one-element list holding the
            scale-factor tensor returned alongside them.

        Raises:
            RuntimeError: If the backend is not FlashInfer TRTLLM.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        # Imported only after the backend check has passed.
        import flashinfer

        quantized = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )
        hidden_states_fp4, hidden_states_sf = quantized
        return hidden_states_fp4, [hidden_states_sf]

1615
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the scales stored on ``layer``.

        Args:
            layer: MoE layer providing the weight scales, second-level weight
                scales and input scales for both w13 and w2.

        Returns:
            Whatever ``make_nvfp4_moe_quant_config`` returns for the selected
            backend and those scales.
        """
        backend = self.nvfp4_backend
        scale_kwargs = {
            "w13_scale": layer.w13_weight_scale,
            "w2_scale": layer.w2_weight_scale,
            "w13_scale_2": layer.w13_weight_scale_2,
            "w2_scale_2": layer.w2_weight_scale_2,
            "a13_scale": layer.w13_input_scale,
            "a2_scale": layer.w2_input_scale,
        }
        return make_nvfp4_moe_quant_config(backend=backend, **scale_kwargs)

1628
1629
1630
1631
    @property
    def supports_eplb(self) -> bool:
        """Always ``True``: this method class reports EPLB as supported."""
        return True

1632
1633
    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 MoE forward pass on the selected backend.

        Three paths exist:
          * FlashInfer TRTLLM without EPLB: one fused call that receives the
            router logits and the layer's routing settings.
          * FlashInfer TRTLLM with EPLB: experts are chosen by ``router``
            first, then the routed TRTLLM entry point is called.
          * Any other backend: experts are chosen by ``router`` and
            ``self.kernel`` is invoked.

        Args:
            layer: MoE layer supplying weights and routing settings.
            router: Used to select experts on the two non-fused paths.
            x: Hidden states; may also arrive as a 2-tuple, in which case
                only its first item is used for expert selection.
            router_logits: Router logits.

        Returns:
            The output of the backend call that was taken.
        """
        is_trtllm = self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

        if is_trtllm and not layer.enable_eplb:
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                custom_routing_function=layer.custom_routing_function,
                e_score_correction_bias=layer.e_score_correction_bias,
            )

        # Hidden_states in select_experts is only used to extract metadata
        x_routing = x
        if isinstance(x, tuple):
            x_routing, _ = x
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x_routing,
            router_logits=router_logits,
        )

        if not is_trtllm:
            assert self.kernel is not None
            return self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # EPLB path
        assert layer.enable_eplb
        return flashinfer_trtllm_fp4_routed_moe(
            layer=layer,
            x=x,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            top_k=layer.top_k,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
        )
1692
1693
1694
1695
1696


# Attach the linear, fused-MoE and KV-cache method classes to the NVFP4 config
# class as class attributes.
# NOTE(review): presumably the config looks these up when choosing a quant
# method for a layer - confirm against ModelOptNvFp4Config.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod