"vscode:/vscode.git/clone" did not exist on "82e6b8646deb1487c0da035235ceed50cf5b69c0"
modelopt.py 60.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.attention import Attention
16
from vllm.model_executor.layers.fused_moe.config import (
17
    FusedMoEConfig,
18
19
    FusedMoEQuantConfig,
)
20
from vllm.model_executor.layers.fused_moe.layer import (
21
22
23
24
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
25
26
27
28
29
30
31
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
32
33
34
35
36
37
38
39
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
40
41
42
43
44
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
45
from vllm.model_executor.layers.quantization import QuantizationMethods
46
from vllm.model_executor.layers.quantization.base_config import (
47
48
49
    QuantizationConfig,
    QuantizeMethodBase,
)
50
51
52
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
53
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
54
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
55
    flashinfer_trtllm_fp4_moe,
56
    flashinfer_trtllm_fp4_routed_moe,
57
)
58
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
59
    apply_fi_trtllm_fp8_per_tensor_moe,
60
)
61
62
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
63
64
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
65
)
66
67
68
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
69
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
70
71
72
73
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
)
74
from vllm.model_executor.layers.quantization.utils.quant_utils import (
75
76
77
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
78
79
80
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
81
82
    kNvfp4Dynamic,
    kNvfp4Static,
83
84
85
    pad_nvfp4_activation_for_cutlass,
    pad_nvfp4_weight_for_cutlass,
    slice_nvfp4_output,
86
87
    swizzle_blockscale,
)
88
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
89
    cutlass_block_fp8_supported,
90
91
    requantize_with_max_scale,
)
92
93
94
95
96
97
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
98
from vllm.model_executor.utils import replace_parameter
99
100
101
102
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
)
103

104
105
106
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

107
108
# Module-level logger for this file, created with vLLM's init_logger helper.
logger = init_logger(__name__)

109
110
111
112
113
114
115
116
117
118
# "quant_algo" values a ModelOpt checkpoint may declare that this module
# accepts. Config values are upper-cased before being compared to this list.
QUANT_ALGOS = [
    "FP8",  # per-tensor FP8 weight scale + optional static activation scale
    "FP8_PER_CHANNEL_PER_TOKEN",  # per-channel weights, per-token activations
    "FP8_PB_WO",  # per-block FP8 weight-only (may be emitted in lowercase)
    "NVFP4",  # FP4
]
# "kv_cache_quant_algo" values listed as supported for KV-cache quantization.
KV_CACHE_QUANT_ALGOS = ["FP8"]
120
121


122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for ModelOpt checkpoints.

    Lets vLLM load the KV-cache scaling factors stored in FP8 ModelOpt
    checkpoints. This class adds no behavior of its own; everything is
    inherited from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Nothing ModelOpt-specific to initialize: hand the config to the base.
        super(ModelOptFp8KVCacheMethod, self).__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base class for ModelOpt quantization configs.

    Subclasses choose the quantize-method classes handed out by
    ``get_quant_method`` by setting ``LinearMethodCls``, ``FusedMoEMethodCls``
    and ``KVCacheMethodCls``, and implement ``_from_config`` so that
    ``from_config`` can build them from a parsed checkpoint config.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # Module names / fnmatch wildcards whose layers stay unquantized.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and ModelOpt wildcard
        matching. The ModelOpt exclude_modules list is a list of wildcards.

        Args:
            prefix: Fully qualified module name of the layer.

        Returns:
            True if ``prefix`` matches an entry of ``exclude_modules`` exactly
            (with fused-layer support), as a plain substring, or as an
            ``fnmatch`` wildcard; False otherwise.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match.
        # No separate check against prefix.removeprefix("language_model.") is
        # needed: stripping a leading prefix leaves a substring of `prefix`, so
        # any match there is already a match on `prefix` itself.
        if any(
            exclude_module != prefix and exclude_module in prefix
            for exclude_module in self.exclude_modules
        ):
            return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(
            fnmatch(prefix, wildcard_pattern)
            for wildcard_pattern in self.exclude_modules
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        * Attention layers always get ``KVCacheMethodCls``.
        * Excluded layers and legacy vision-tower layers stay unquantized:
          linear layers get ``UnquantizedLinearMethod``; any other layer type
          gets ``None`` (no quantize method).
        * Remaining linear / fused-MoE layers get ``LinearMethodCls`` /
          ``FusedMoEMethodCls``; other layer types get ``None``.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # TODO: The vision-tower special case is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where such layers are
        # listed in exclude_modules. It is kept for loading checkpoints
        # generated by older versions.
        is_legacy_vision_layer = "vision_tower" in prefix or "vision_model" in prefix
        if self.is_layer_excluded(prefix) or is_legacy_vision_layer:
            # Only linear layers can use the unquantized *linear* method; other
            # layer types (e.g. FusedMoE) signal "unquantized" with None.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
        else:
            return None

        # Marlin-backed methods additionally get a per-layer input dtype,
        # looked up by the layer prefix.
        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``exclude_modules`` entries from HF names to vLLM names."""
        if not self.exclude_modules:
            return

        # This is a workaround for the weights remapping issue:
        # https://github.com/vllm-project/vllm/issues/28072
        # Right now, the Nvidia ModelOpt library uses just one wildcard pattern:
        #        module_path*
        # It gets applied if the whole tree of modules rooted at module_path
        # is not quantized. Here we replace such pattern by 2 patterns that are
        # collectively equivalent to the original pattern:
        #        module_path
        #        module_path.*
        new_exclude_modules: list[str] = []
        for exclude in self.exclude_modules:
            if (
                len(exclude) >= 2
                and exclude.endswith("*")
                and not exclude.endswith(".*")
            ):
                new_exclude_modules.append(exclude[:-1])
                new_exclude_modules.append(exclude[:-1] + ".*")
            else:
                new_exclude_modules.append(exclude)

        self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from already-validated fields."""
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse a ModelOpt quantization config and build the concrete config.

        Two layouts are accepted:

        * legacy ``hf_quant_config.json``:
          ``{"quantization": {"quant_algo": ..., "kv_cache_quant_algo": ...,
          "group_size": ..., "exclude_modules": [...]}}``
        * compressed-tensors style ``config.json``:
          ``{"quant_algo": ..., "kv_cache_quant_algo": ..., "group_size": ...,
          "ignore": [...], "quant_method": "modelopt"}``

        ``quant_algo`` and ``kv_cache_quant_algo`` are upper-cased; a missing
        or null exclusion list is treated as empty.

        Raises:
            ValueError: if ``quantization`` is not a dict, ``quant_algo`` is
                missing or not in ``QUANT_ALGOS``, ``kv_cache_quant_algo`` is
                not a string, the exclusion list is not a list, or
                ``group_size`` cannot be converted to an int.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            # Raw values; type validation happens below.
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            group_size_raw = config.get("group_size")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        # None means no KV-cache quantization.
        if kv_cache_quant_method is not None:
            if not isinstance(kv_cache_quant_method, str):
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_method)}"
                )
            kv_cache_quant_method = kv_cache_quant_method.upper()

        # An explicit JSON null means "nothing excluded".
        if exclude_modules is None:
            exclude_modules = []
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8.

    Handles the ``FP8``, ``FP8_PER_CHANNEL_PER_TOKEN`` and ``FP8_PB_WO``
    ``quant_algo`` values; the linear method is selected from ``quant_algo``.
    """

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """
        Args:
            quant_method: Upper-cased ModelOpt ``quant_algo``.
            is_checkpoint_fp8_serialized: Whether weights are stored in FP8.
            kv_cache_quant_method: Upper-cased ``kv_cache_quant_algo`` or None.
            exclude_modules: Module names / wildcards to leave unquantized.

        Raises:
            ValueError: if ``quant_method`` is not one of the FP8 algorithms
                listed above.
        """
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo.
        if self.quant_method == "FP8":
            self.LinearMethodCls = ModelOptFp8LinearMethod
        elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
            self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
        elif self.quant_method == "FP8_PB_WO":
            self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
        else:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns "modelopt" when ``quant_method`` is "modelopt"
        (case-insensitive) and the ``quant_algo`` contains "FP8". The
        ``quant_algo`` is read from the nested ``quantization`` dict when
        present, otherwise from the top level. Returns None in every other
        case, including an explicit null ``quant_method``.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. `or ""` also covers an
        # explicit null, which previously crashed on `None.lower()`.
        quant_method = str(hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Legacy ModelOpt layout nests quant_algo under "quantization";
        # compressed-tensors style configs keep it at the top level.
        if "quantization" in hf_quant_cfg:
            algo_source = hf_quant_cfg["quantization"]
            if not isinstance(algo_source, dict):
                return None
        else:
            algo_source = hf_quant_cfg

        quant_algo = str(algo_source.get("quant_algo", ""))
        if "FP8" in quant_algo.upper():
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config; extra keyword fields (e.g. group_size) are ignored."""
        # The FP8-serialized flag is derived from the algorithm name.
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
429

430
431
432
433

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt static FP8 quantization.

    Loads an FP8 checkpoint whose Linear layers carry a static per-tensor
    weight scale and a static per-tensor input (activation) scale, and runs
    the GEMM through the kernel returned by ``init_fp8_linear_kernel``.
    Dynamic activation scales are not handled by this method.

    Limitations:
    1. Only per-tensor quantization is supported.
    2. Only the float8_e4m3fn weight dtype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Weights and activations both use static, per-tensor symmetric scales.
        self.fp8_linear = init_fp8_linear_kernel(
            module_name=self.__class__.__name__,
            out_dtype=torch.get_default_dtype(),
            weight_quant_key=kFp8StaticTensorSym,
            activation_quant_key=kFp8StaticTensorSym,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` (plus per-shard scales for FP8 checkpoints)."""
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        total_out = sum(output_partition_sizes)
        num_shards = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        # FP8-serialized checkpoints store the weight directly as FP8.
        if fp8_serialized:
            storage_dtype = torch.float8_e4m3fn
        else:
            storage_dtype = params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_out, input_size_per_partition, dtype=storage_dtype
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # One weight scale and one input scale per logical shard, both
        # pre-filled with the lowest finite float32 value.
        for param_name in ("weight_scale", "input_scale"):
            shard_scales = PerTensorScaleParameter(
                data=torch.empty(num_shards, dtype=torch.float32),
                weight_loader=loader,
            )
            shard_scales[:] = torch.finfo(torch.float32).min
            layer.register_parameter(param_name, shard_scales)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales to one scale each and transpose the weight."""
        shard_scales = layer.weight_scale
        if (shard_scales == shard_scales[0]).all():
            # Every shard already shares one scale; keep the weight untouched.
            w_scale, w = shard_scales.max(), layer.weight
        else:
            # Shards disagree: requantize_with_max_scale returns a single max
            # scale together with the matching re-quantized weight.
            w_scale, w = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )

        # The weight is stored transposed relative to its [out, in] load layout.
        layer.weight = Parameter(w.t(), requires_grad=False)
        layer.weight_scale = Parameter(w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM via the selected kernel."""
        return self.fp8_linear.apply_weights(layer, x, bias)
518
519


520
521
522
523
524
525
526
527
528
529
530
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8_PER_CHANNEL_PER_TOKEN`` checkpoints.

    Each Linear in the checkpoint is expected to carry:

    * ``weight``: float8_e4m3fn, shape ``[out, in]``;
    * ``weight_scale``: float32, shape ``[out]`` (one scale per output channel).

    No ``input_scale`` is stored: activations are quantized dynamically per
    token by the FP8 linear kernel.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static per-row weight scales; dynamic per-token activation scales.
        self.fp8_linear = init_fp8_linear_kernel(
            module_name=self.__class__.__name__,
            out_dtype=torch.get_default_dtype(),
            weight_quant_key=kFp8StaticTokenSym,
            activation_quant_key=kFp8DynamicTokenSym,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and per-channel ``weight_scale``.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        loader = extra_weight_attrs.get("weight_loader")
        total_out = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        fp8_storage = torch.empty(
            total_out, input_size_per_partition, dtype=torch.float8_e4m3fn
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=fp8_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # One float32 scale per output row, pre-filled with the lowest
        # finite float32 value.
        per_channel = ChannelQuantScaleParameter(
            data=torch.empty(total_out, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        per_channel[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", per_channel)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Transpose the weight and re-wrap both tensors as frozen Parameters."""
        transposed = layer.weight.t()
        scale_data = layer.weight_scale.data
        layer.weight = Parameter(transposed, requires_grad=False)
        layer.weight_scale = Parameter(scale_data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM via the selected kernel."""
        return self.fp8_linear.apply_weights(layer, x, bias)
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8_PB_WO`` checkpoints.

    ModelOpt exports ``weight_scale`` as a 4D tensor ``[out_blk, 1, in_blk, 1]``
    holding one float32 scale per ``_WEIGHT_BLOCK_SIZE`` (128 x 128) weight
    block.

    vLLM executes the layer as a block-scaled FP8 GEMM
    (``W8A8BlockFp8LinearOp``) with activations quantized dynamically in
    groups of shape ``(1, block_k)``, i.e. per token and per ``block_k``
    input channels.
    """

    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 ``weight`` and block ``weight_scale`` parameters.

        Raises:
            ValueError: if the checkpoint is not FP8-serialized, or if the
                output size, the input size, or any logical output partition
                is not a multiple of the block size.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        output_size_per_partition = sum(output_partition_sizes)

        # Validate block alignment before registering anything on the layer.
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )
        # For fused layers (several logical partitions) every partition must
        # be block-aligned too: the weight loader translates each partition's
        # offset from element space to block space, which is only exact when
        # the partition starts and ends on a block boundary.
        if any(size % block_n != 0 for size in output_partition_sizes):
            raise ValueError(
                "ModelOpt FP8_PB_WO requires every output partition to be "
                f"divisible by {block_n}, got {output_partition_sizes}."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Freeze the weight and reduce ``weight_scale`` to ``[out_blk, in_blk]``.

        Raises:
            ValueError: if ``weight_scale`` is neither 2D nor 4D with
                singleton dims 1 and 3.
        """
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # squeeze() is a silent no-op on non-singleton dims, so check
            # explicitly that this really is [out_blk, 1, in_blk, 1].
            if scale.shape[1] != 1 or scale.shape[3] != 1:
                raise ValueError(
                    "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                    f"{tuple(scale.shape)}."
                )
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-scaled FP8 GEMM; activation scales are computed on the fly."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    static per-tensor activation scales.

    Args:
        quant_config: The ModelOpt quantization config.
        moe_config: The fused-MoE layer configuration (forwarded to the base).
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        # This method only handles checkpoints that are already FP8-serialized.
        assert self.quant_config.is_checkpoint_fp8_serialized

        # Select the FP8 MoE backend and expert implementation for static
        # per-tensor weight and activation scales.
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        # The modular kernel is built in `_setup_kernel`; this legacy hook
        # must never be reached for this method.
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        # Legacy hook, intentionally unsupported (kernel built in `_setup_kernel`).
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 expert weights and their per-tensor scales.

        Registers on ``layer``:
          - ``w13_weight``: [num_experts, shards * intermediate, hidden], where
            ``shards`` is 2 for act-and-mul (gated) MoE and 1 otherwise.
          - ``w2_weight``: [num_experts, hidden, intermediate]
          - ``w13_weight_scale``: [num_experts, shards], initialised to 1.0
          - ``w2_weight_scale``, ``w13_input_scale``, ``w2_input_scale``:
            [num_experts], initialised to 1.0
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stores w1 and w3 stacked along the output dim.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, w13_num_shards),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend's layout and build the modular kernel.

        The converted weights and scales replace the layer's parameters. The
        modular kernel (``self.moe_mk``) is only built when
        ``get_fused_moe_quant_config`` returns a config.
        """
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse loaded scales to what per-tensor kernels expect, then set up.

        Input scales are reduced to a single activation scale each (see the
        helper), and the separate w1/w3 weight scales are merged by requantizing
        w13. Finally the weights are converted and the kernel is built.
        """
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's current scales."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    @property
    def is_monolithic(self) -> bool:
        # Only the FlashInfer TRTLLM backend performs routing + experts in one
        # call (see `apply_monolithic`).
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and experts in a single FlashInfer TRTLLM FP8 call.

        Raises:
            NotImplementedError: if EPLB is enabled on the layer.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
            )
        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )
        assert not layer.renormalize
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the pre-built modular kernel on already-routed tokens."""
        assert not self.is_monolithic

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.moe_mk is not None
        return self.moe_mk(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

# Attach the per-layer quantization method classes to the FP8 config.
# NOTE(review): presumably read by the ModelOpt config base class to choose the
# method for linear / fused-MoE / KV-cache layers — confirm against the base.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """
        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint already
                stores NVFP4-quantized weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name from the
                checkpoint config, or None.
            exclude_modules: Modules excluded from quantization (forwarded to
                the base config).
            group_size: Number of input elements that share one block scale.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set unconditionally so these attributes always exist; previously they
        # were only assigned for serialized checkpoints, leaving a
        # non-serialized config without them.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum supported compute capability, encoded as major*10 + minor.
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed checkpoint fields.

        Raises:
            ValueError: if an NVFP4 checkpoint's ``quantization`` section is
                missing ``group_size``, ``kv_cache_quant_algo`` or
                ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters (sizes
    are per tensor-parallel partition):

    input_scale: torch.float32, one scalar per fused output partition.
    weight: NVFP4 packed two values per byte (torch.uint8),
        shape [out_features, in_features // 2].
    weight_scale: FP8-E4M3 per-block scale,
        shape [out_features, in_features // group_size].
    weight_scale_2: torch.float32, one scalar per fused output partition.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        # Resolve the GEMM backend: auto-detect when the env var is unset,
        # otherwise honour the explicit choice and check it is available.
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin":
            self.backend = "marlin"
            assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 weight, per-block scale and global scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or the input size per partition
                is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )
        # dtype of the per-block scale: FP8-E4M3 for serialized NVFP4
        # checkpoints. The NVFP4 weight itself is packed into uint8 below.
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale (one per fused output partition)
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale (one per fused output partition)
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce global scales and re-layout weights for the chosen backend."""
        # global scales: collapse the per-partition values to their max.
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            # Swizzle block scales and pad the packed NVFP4 weights for kernel
            # alignment (CUTLASS/FlashInfer require K and N divisible by 32).
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)

            weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
                layer.weight.data
            )
            layer.weights_padding_cols = weights_padding_cols
            layer.weight = Parameter(weight, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the NVFP4 linear layer output for ``x`` (plus ``bias``).

        Activations of shape ``[..., K]`` yield an output of shape ``[..., N]``,
        where ``N`` is ``layer.output_size_per_partition``.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        output_dtype = x.dtype
        output_size = layer.output_size_per_partition
        # Keep every leading dim so N-D activations are restored correctly by
        # the final view; for 2-D input this equals [x.shape[0], output_size].
        output_shape = [*x.shape[:-1], output_size]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(
            x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
        )

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        # Pad activations to match weight K-dimension padding
        weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
        x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        # Slice output to remove N-dimension padding
        out = slice_nvfp4_output(out, output_size)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Args:
        quant_config: NVFP4 Quant Config
        moe_config: Fused MoE layer configuration, forwarded to
            ``FusedMoEMethodBase.__init__``.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config

        # Select experts implementation. The backend choice drives weight
        # conversion, kernel construction and the monolithic/EPLB dispatch
        # in the methods below.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

        # When the backend supports global scale factors, create_weights
        # allocates input scales for all global experts instead of only the
        # experts local to this rank.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """
        Not used by this method; the modular kernel is built in
        ``process_weights_after_loading`` instead.

        Raises:
            ValueError: always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """
        Not used by this method; the modular kernel is built in
        ``process_weights_after_loading`` instead.

        Raises:
            ValueError: always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Register the (still checkpoint-format) NVFP4 expert parameters.

        Registers on ``layer``:
          * ``w13_weight`` / ``w2_weight``: uint8 tensors holding two FP4
            values per byte along the input dimension.
          * ``w13_weight_scale`` / ``w2_weight_scale``: float8_e4m3fn
            block scales, one per ``quant_config.group_size`` input elements.
          * ``w13_weight_scale_2`` / ``w2_weight_scale_2``: float32
            per-tensor weight scales, one per expert (and per w13 shard).
          * ``w13_input_scale`` / ``w2_input_scale``: float32 per-tensor
            input scales, sized by the global expert count when the backend
            uses global scale factors.

        Args:
            layer: The MoE layer that receives the parameters.
            num_experts: Number of experts held by this layer.
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Expert intermediate size on
                this partition.
            params_dtype: Stored on ``layer.params_dtype``.
            **extra_weight_attrs: Must provide ``weight_loader``; also reads
                ``global_num_experts`` (needed when ``use_global_sf``).
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        # w13 stacks two shards (w1 and w3) for gated activations, else one.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # one FP8 block scale per group_size input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one FP8 block scale per group_size input elements
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): extra_weight_attrs is not read again in this method
        # (the parameters only receive weight_loader), so this update has no
        # visible effect here — confirm whether it is still needed.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Backends with global scale factors keep input scales for every
        # global expert, not just the local ones.
        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.
        """

        # Use a single gscale for w13. The kernels take one per-tensor scale
        # for the fused w13, so the w1 scale (column 0) is used for both
        # shards; a mismatching w3 scale is only warned about.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        # Re-layout weights and scales for the selected backend.
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Swap the converted tensors in under the original parameter names.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    @property
    def do_post_quant_allgather(self):
        """True when activations are quantized before the DP allgather."""
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes ``hidden_states`` to FP4 with ``layer.a1_gscale`` and
        returns the quantized tensor plus its (non-swizzled) scale factors
        as the single extra tensor. ``router_logits`` is not used.

        Raises:
            RuntimeError: if the backend is not FlashInfer TRTLLM.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        # Imported lazily so flashinfer is only required on this path.
        import flashinfer

        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's current scales.

        Returns None when the backend helper produces no config, in which
        case no modular kernel is created.
        """
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        """FlashInfer TRTLLM does routing itself unless EPLB is enabled."""
        return (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not self.moe.moe_parallel_config.enable_eplb
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and experts in one FlashInfer TRTLLM call."""
        assert self.is_monolithic
        assert (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        )

        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=layer.top_k,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            custom_routing_function=layer.custom_routing_function,
            e_score_correction_bias=layer.e_score_correction_bias,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts on already-computed routing (topk ids/weights)."""
        assert not self.is_monolithic

        # EPLB path: FlashInfer TRTLLM only reaches this non-monolithic entry
        # point when EPLB is enabled (see is_monolithic).
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
            )
        else:
            # Modular kernel built in process_weights_after_loading.
            assert self.moe_mk is not None
            return self.moe_mk(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )


# Bind the NVFP4 config to the quantize-method classes it instantiates for
# linear layers, fused-MoE layers and the KV cache.
(
    ModelOptNvFp4Config.LinearMethodCls,
    ModelOptNvFp4Config.FusedMoEMethodCls,
    ModelOptNvFp4Config.KVCacheMethodCls,
) = (
    ModelOptNvFp4LinearMethod,
    ModelOptNvFp4FusedMoE,
    ModelOptFp8KVCacheMethod,
)