modelopt.py 59.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
18
19
20
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
22
from vllm.model_executor.layers.fused_moe.layer import (
23
24
25
26
27
28
29
30
31
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
32
from vllm.model_executor.layers.quantization import QuantizationMethods
33
from vllm.model_executor.layers.quantization.base_config import (
34
35
36
    QuantizationConfig,
    QuantizeMethodBase,
)
37
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
38
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
39
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
40
41
    flashinfer_trtllm_fp4_moe,
    prepare_static_weights_for_trtllm_fp4_moe,
42
43
44
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
45
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
46
47
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
48
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
49
50
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
51
    is_flashinfer_supporting_global_sf,
52
53
54
55
56
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
57
58
59
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
60
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
61
62
63
64
65
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
66
from vllm.model_executor.layers.quantization.utils.quant_utils import (
67
68
69
70
71
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
72
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
73
74
75
76
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
77
from vllm.scalar_type import scalar_types
78
79
80
81
82
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
83

84
85
86
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

87
88
logger = init_logger(__name__)

# Weight-quantization algorithms accepted by `ModelOptQuantConfigBase.from_config`.
QUANT_ALGOS = ["FP8", "NVFP4"]
# KV-cache quantization algorithms.
# NOTE(review): not referenced in the visible part of this file — confirm usage.
KV_CACHE_QUANT_ALGOS = ["FP8"]
91
92


93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    All behavior is inherited from ``BaseKVCacheMethod``; this subclass only
    narrows the accepted config type to the ModelOpt config family.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Pure pass-through to the base class constructor.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base for the ModelOpt quantization configs.

    Keeps the list of excluded modules, dispatches each layer to a quantization
    method through the ``*MethodCls`` class attributes, and parses/validates the
    raw checkpoint config before handing it to the subclass ``_from_config`` hook.
    """

    # Method classes instantiated by ``get_quant_method``; concrete configs
    # are expected to rebind these to their own implementations.
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        """Store the module names / wildcard patterns excluded from quantization."""
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and ModelOpt wildcard matching.

        The ModelOpt exclude_modules list is a list of wildcards.

        Args:
            prefix: Fully qualified name of the layer being inspected.

        Returns:
            True if any of the exact, substring or wildcard checks matches.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match.
        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module != prefix and (
                exclude_module in prefix
                or (
                    prefix.startswith("language_model.")
                    and exclude_module in prefix.removeprefix("language_model.")
                )
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(
            fnmatch(prefix, wildcard_pattern)
            for wildcard_pattern in self.exclude_modules
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module being set up.
            prefix: Fully qualified name of ``layer`` inside the model.

        Returns:
            A KV-cache method for attention layers, an unquantized linear method
            for excluded linear layers and for vision-tower/vision-model
            prefixes, the configured linear or fused-MoE method for quantized
            layers, and ``None`` for everything else.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
            # Only methods exposing a "marlin" backend get an input dtype attached.
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` through ``hf_to_vllm_mapper.apply_list``."""
        if self.exclude_modules:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the config file names associated with ModelOpt checkpoints."""
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from values validated by ``from_config``.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a raw config dict, then delegate to ``_from_config``.

        Two layouts are accepted: the traditional ModelOpt layout with a nested
        ``"quantization"`` dict, and a flat compressed-tensors style layout.

        Raises:
            ValueError: If ``quant_algo`` is missing or not in ``QUANT_ALGOS``,
                or if ``kv_cache_quant_algo``, the exclusion list or
                ``group_size`` has an unexpected type.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # None means "no KV cache quantization"; anything else must be a string.
        if kv_cache_quant_method is not None and not isinstance(
            kv_cache_quant_method, str
        ):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            # Accept values that convert cleanly to int (e.g. numeric strings).
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the FP8 settings and warn when the checkpoint is FP8-serialized.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint already stores
                FP8 weights and scales.
            kv_cache_quant_method: KV-cache quantization algorithm, if any.
            exclude_modules: Module names / wildcards excluded from quantization.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name for this config."""
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes supported with this config."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value required (89)."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Args:
            hf_quant_cfg: Quantization config dict from the checkpoint, or None.
            user_quant: Quantization method requested by the user (unused here).

        Returns:
            ``"modelopt"`` when ``quant_method`` is "modelopt" (case-insensitive)
            and the quantization algorithm string contains "FP8"; otherwise None.
            Malformed (non-string) values yield None instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'.
        # Only proceed if the method is explicitly "modelopt"; a null or
        # non-string value is treated as "not modelopt" rather than crashing.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Guard both layouts against null / non-string quant_algo values.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config from values validated by ``from_config``.

        Extra keyword arguments (such as ``group_size``) are accepted and
        ignored.
        """
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
354

355
356
357
358

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static, per-tensor activation quantization.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, its scales on ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input dimension of this partition.
            output_partition_sizes: Output widths of the logical sub-matrices
                held by this layer.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype used when the checkpoint is not
                FP8-serialized.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE: one entry per logical sub-matrix, initialised to the
            # lowest float32 so any loaded scale wins the later max().
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE: same layout and initialisation as the weight scale.
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the scales to single values and store the weight transposed.

        When the per-shard weight scales differ, the weight is requantized with
        the maximum scale via ``requantize_with_max_scale``.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on ``x`` using the layer's weight and scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
446
447


448
449
450
451
452
453
454
455
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The fused MoE layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # Stays None unless FlashInfer FP8 MoE is enabled and available.
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            if (
                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and not self.moe.is_act_and_mul
            ):
                logger.info_once(
                    "Non-gated MoE is not supported for min-latency mode, "
                    "falling back to high-throughput mode"
                )
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the active backend, if any."""
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Select the CUTLASS FP8 experts implementation.

        ``prepare_finalize`` and ``layer`` are accepted but not read here.
        """
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts held by this layer.
            hidden_size: Hidden dimension of the model.
            intermediate_size_per_partition: Intermediate dimension of this
                partition.
            params_dtype: Weight dtype used when the checkpoint is not
                FP8-serialized.
            **extra_weight_attrs: ``weight_loader`` is read; ``quant_method``
                is written for FP8 checkpoints.
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stacks w1 and w3 along the output dimension of w13.
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale

                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            if self.moe.is_act_and_mul:
                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        # Presumably attaches the output*_scales_* / *_input_scale_inv attributes
        # read by get_fused_moe_quant_config — confirm against the helper.
        register_moe_scaling_factors(layer)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 MoE quant config; None for the TensorRT-LLM backend."""
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        # NOTE(review): a1_gscale receives w13_input_scale while a2_gscale
        # receives w2_input_scale_inv — confirm this asymmetry is intended.
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=layer.output2_scales_scalar.squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass with the backend chosen at construction.

        Args:
            layer: The fused MoE layer holding weights, scales and routing setup.
            x: Hidden states fed to the experts.
            router_logits: Router outputs used for expert selection.

        Raises:
            NotImplementedError: If EPLB is enabled with the TensorRT-LLM backend.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
                )
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            assert not layer.renormalize
            # This path receives router_logits directly; no select_experts call.
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=layer.activation,
                quant_config=self.moe_quant_config,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
772
773


774
775
776
777
778
779
# Bind the FP8 method classes to the FP8 config. This is done after the class
# bodies because the method classes are defined later in the file than the config.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4.

    Instance attributes (always set by ``__init__``):
        is_checkpoint_nvfp4_serialized: True when the checkpoint's quantization
            method string contains ``"NVFP4"``.
        group_size: Number of elements sharing one block scale (default 16).
        kv_cache_quant_algo: KV-cache quantization algorithm name, or None.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the NVFP4 settings and warn about the experimental format.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint is NVFP4.
            kv_cache_quant_algo: KV-cache quantization algorithm, if any.
            exclude_modules: Module name patterns forwarded to the base config.
            group_size: Block-scale group size.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Set unconditionally: these used to be assigned only inside the
        # "serialized" branch below, which left a non-serialized config without
        # the attributes and turned later reads into AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name for this config."""
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted with this config."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value required (80)."""
        return 80

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Args:
            hf_quant_cfg: Quantization config mapping from the checkpoint, or
                None when the checkpoint has none.
            user_quant: Quantization method requested by the user (unused).

        Returns:
            ``"modelopt_fp4"`` when the config identifies a ModelOpt FP4
            checkpoint, otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. A missing or non-string
        # value (e.g. JSON null) means "not ours" rather than a crash.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                # Same string guard as the flat layout below.
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from already-extracted quantization fields.

        Args:
            quant_method: Quantization algorithm string from the checkpoint.
            kv_cache_quant_method: KV-cache quantization algorithm, if any.
            exclude_modules: Module name patterns to leave unquantized.
            original_config: The raw config the fields were extracted from;
                used only to validate required keys.
            group_size: Block-scale group size; None selects the default 16.
            **kwargs: Extra extracted fields, ignored here.

        Raises:
            ValueError: If an NVFP4 config has a ``"quantization"`` section
                that lacks any of the required fields.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )
878
879
880
881
882


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        """Pick the NVFP4 GEMM backend.

        With ``VLLM_NVFP4_GEMM_BACKEND`` unset the first available backend of
        flashinfer-cutlass, cutlass, marlin is used; otherwise the requested
        ``flashinfer-*`` or ``cutlass`` backend is asserted to be available.

        Raises:
            ValueError: If no backend could be selected.
        """
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        self.backend = "none"
        requested = envs.VLLM_NVFP4_GEMM_BACKEND
        if requested is None:
            # Auto-detect: probes run lazily, in priority order.
            candidates = (
                ("flashinfer-cutlass", has_flashinfer),
                ("cutlass", cutlass_fp4_supported),
                ("marlin", is_fp4_marlin_supported),
            )
            for candidate, is_available in candidates:
                if is_available():
                    self.backend = candidate
                    break
        elif requested.startswith("flashinfer-"):
            self.backend = requested
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weight and its three scale parameters on layer.

        Registers, in order: ``weight``, ``input_scale``, ``weight_scale_2``
        and ``weight_scale``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or the
                per-partition input size is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)
        num_partitions = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Block scales are stored as FP8-E4M3 for serialized checkpoints.
        if self.quant_config.is_checkpoint_nvfp4_serialized:
            block_scale_dtype = torch.float8_e4m3fn
        else:
            block_scale_dtype = params_dtype

        # Weight: 2 fp4 items are packed per byte along the input dimension.
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    out_features,
                    input_size_per_partition // 2,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # Input scale: one float32 entry per output partition.
        layer.register_parameter(
            "input_scale",
            PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=loader,
            ),
        )

        # Global weight scale: one float32 entry per output partition.
        layer.register_parameter(
            "weight_scale_2",
            PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=loader,
            ),
        )

        # Per-block weight scale: one entry per group of input elements.
        layer.register_parameter(
            "weight_scale",
            ModelWeightParameter(
                data=torch.empty(
                    out_features,
                    input_size_per_partition // self.quant_config.group_size,
                    dtype=block_scale_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the global scales and lay out weights for the backend.

        Reduces ``input_scale`` and ``weight_scale_2`` to their maxima, derives
        ``alpha`` and ``input_scale_inv``, then re-packs ``weight`` and
        ``weight_scale`` according to ``self.backend``.
        """

        def frozen(tensor: torch.Tensor) -> Parameter:
            return Parameter(tensor, requires_grad=False)

        # Global scales: a single value shared by all fused partitions.
        layer.input_scale = frozen(layer.input_scale.max().to(torch.float32))
        layer.weight_scale_2 = frozen(layer.weight_scale_2.max().to(torch.float32))
        layer.alpha = frozen(layer.input_scale * layer.weight_scale_2)

        # Pre-compute `1 / input_scale` so it is not recomputed at runtime.
        layer.input_scale_inv = frozen((1 / layer.input_scale).to(torch.float32))

        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
            return

        if self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer's nvfp4_quantize would quantize + shuffle, but the
            # weights are already quantized, so only the shuffles are applied.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            tile_m = 128
            packed_weight = layer.weight.data
            block_scale = layer.weight_scale.data

            shuffled_weight = shuffle_matrix_a(packed_weight.view(torch.uint8), tile_m)
            shuffled_scale = shuffle_matrix_sf_a(block_scale.view(torch.uint8), tile_m)
            shuffled_scale = shuffled_scale.reshape(block_scale.shape)
            shuffled_scale = shuffled_scale.view(torch.float8_e4m3fn)

            layer.weight_scale = frozen(shuffled_scale)
            layer.weight = frozen(shuffled_weight)
            return

        # Remaining backends: swizzle the weight block scale. The contracting
        # dimension is the input dimension.
        layer.weight_scale = frozen(swizzle_blockscale(layer.weight_scale))
        layer.weight = frozen(layer.weight.data)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear layer on ``x`` with the selected backend.

        Returns:
            The GEMM output (plus ``bias`` when given). For non-marlin
            backends it is viewed as ``[x.shape[0], layer.weight.shape[0]]``.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        result_dtype = x.dtype
        num_rows, num_cols = x.shape[0], layer.weight.shape[0]

        # Quantize the activation to FP4 plus an interleaved block scale.
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # Validate dtypes of quantized input, input block scale,
        # weight and weight block scale.
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        gemm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            result_dtype,
        )

        if self.backend.startswith("flashinfer-"):
            result = flashinfer_scaled_fp4_mm(
                *gemm_args, backend=self.backend.removeprefix("flashinfer-")
            )
        else:
            assert self.backend == "cutlass"
            result = cutlass_scaled_fp4_mm(*gemm_args)

        if bias is not None:
            result = result + bias
        return result.view(num_rows, num_cols)
1096
1097
1098
1099
1100


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1101
    Args:
1102
1103
1104
        quant_config: NVFP4 Quant Config
    """

1105
1106
1107
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        layer: FusedMoE,
    ) -> None:
        """Record the quant config and detect which MoE kernel family to use.

        Args:
            quant_config: NVFP4 quantization config.
            layer: The fused-MoE layer; its ``moe_config`` is passed to the
                base class.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.layer = layer

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin
        self.marlin_input_dtype = None

        # Stays None unless FlashInfer is allowed.
        self.flashinfer_moe_backend = None
        if not self.allow_flashinfer:
            if self.use_marlin:
                logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
            else:
                logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
            return

        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )
1133

1134
1135
1136
1137
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the active backend, if any.

        Returns None for Marlin and for the FlashInfer TensorRT-LLM backend,
        a FlashInfer FP4 CUTLASS prepare/finalize for the FlashInfer CUTLASS
        backend, and defers to the base class otherwise.
        """
        if self.use_marlin:
            return None

        if self.allow_flashinfer:
            backend = self.flashinfer_moe_backend
            if backend == FlashinferMoeBackend.TENSORRT_LLM:
                return None
            if backend == FlashinferMoeBackend.CUTLASS:
                # For now, fp4 moe only works with the flashinfer dispatcher.
                flashinfer_pf = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                    self.moe
                )
                logger.debug_once("%s", flashinfer_pf.__class__.__name__)
                return flashinfer_pf

        return super().maybe_make_prepare_finalize(routing_tables)
1155

1156
1157
1158
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the NVFP4 expert GEMM implementation.

        Delegates to ``select_nvfp4_gemm_impl`` using this method's MoE config
        and quant config; ``prepare_finalize`` and ``layer`` are not consulted.
        """
        quant_config = self.moe_quant_config
        assert quant_config is not None
        gemm_impl = select_nvfp4_gemm_impl(
            self.moe,
            quant_config,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", gemm_impl.__class__.__name__)
        return gemm_impl
1169

1170
1171
1172
1173
1174
1175
    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales are named ``*_weight_scale_2``.

        FP4 variants use the 'weight_scale_2' pattern for per-tensor weight
        scales (see ``w13_weight_scale_2`` / ``w2_weight_scale_2`` registered
        in ``create_weights``).

        Returns:
            Always True.
        """
        return True

1176
1177
1178
1179
1180
1181
1182
1183
1184
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 expert weights and scales on ``layer``.

        Registers, in order: ``w13_weight``, ``w2_weight``,
        ``w13_weight_scale``, ``w2_weight_scale``, ``w13_weight_scale_2``,
        ``w2_weight_scale_2``, ``w13_input_scale`` and ``w2_input_scale``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        packed_dtype = torch.uint8
        block_scale_dtype = torch.float8_e4m3fn
        loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")

        # Gated activations fuse w1 and w3, doubling the w13 output rows.
        w13_parts = 2 if self.moe.is_act_and_mul else 1
        w13_rows = w13_parts * intermediate_size_per_partition
        group_size = self.quant_config.group_size

        def expert_matrix(rows: int, cols: int, dtype: torch.dtype):
            # One [rows, cols] matrix per local expert.
            return ModelWeightParameter(
                data=torch.empty(num_experts, rows, cols, dtype=dtype),
                input_dim=1,
                output_dim=2,
                weight_loader=loader,
            )

        def expert_scalars(*shape: int):
            # float32 per-tensor scales.
            return PerTensorScaleParameter(
                data=torch.empty(*shape, dtype=torch.float32),
                weight_loader=loader,
            )

        # GEMM 1 / GEMM 2 weights: 2 fp4 items are packed per byte along the
        # input dimension, hence the halved last dimension.
        layer.register_parameter(
            "w13_weight", expert_matrix(w13_rows, hidden_size // 2, packed_dtype)
        )
        layer.register_parameter(
            "w2_weight",
            expert_matrix(
                hidden_size, intermediate_size_per_partition // 2, packed_dtype
            ),
        )

        # Block scales: one entry per `group_size` input elements.
        layer.register_parameter(
            "w13_weight_scale",
            expert_matrix(w13_rows, hidden_size // group_size, block_scale_dtype),
        )
        layer.register_parameter(
            "w2_weight_scale",
            expert_matrix(
                hidden_size,
                intermediate_size_per_partition // group_size,
                block_scale_dtype,
            ),
        )

        extra_weight_attrs["quant_method"] = FusedMoeWeightScaleSupported.BLOCK.value

        # Per-tensor (global) weight scales.
        layer.register_parameter(
            "w13_weight_scale_2", expert_scalars(num_experts, w13_parts)
        )
        layer.register_parameter("w2_weight_scale_2", expert_scalars(num_experts))

        extra_weight_attrs["quant_method"] = FusedMoeWeightScaleSupported.TENSOR.value

        # Input scales are sized by the global expert count when the
        # FlashInfer backend reports global scale-factor support, and by the
        # local expert count otherwise.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            input_scale_experts = global_num_experts
        else:
            input_scale_experts = num_experts

        layer.register_parameter(
            "w13_input_scale", expert_scalars(input_scale_experts, w13_parts)
        )
        layer.register_parameter("w2_input_scale", expert_scalars(input_scale_experts))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Finalize loaded NVFP4 MoE weights for the selected kernel backend.

        Steps, in order:
          1. Optionally reorder the fused w1/w3 halves of ``w13`` (FlashInfer
             CUTLASS / TensorRT-LLM with gated activations).
          2. Collapse ``w13_weight_scale_2`` to one value per expert.
          3. Derive ``g1_alphas`` / ``g2_alphas`` and the inverted input
             scales ``w13_input_scale_quant`` / ``w2_input_scale_quant``.
          4. Apply the backend-specific layout: TensorRT-LLM shuffling, Marlin
             preparation, or block-scale swizzling (with padding if needed).

        Args:
            layer: The MoE layer whose parameters were populated by the
                weight loader; it is modified in place.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if (
            self.allow_flashinfer
            and (
                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
                or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
            )
            and self.moe.is_act_and_mul
        ):
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        # Only the first column is kept; the warning above flags a mismatch.
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            # The swizzled scale may have more rows than the weight; the
            # difference is how far the intermediate dimension must be padded.
            w13_weight = layer.w13_weight
            intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
            if intermediate_size_pad:
                # padding gated activations will require to split w1 and w3
                # and pad them individually
                assert not self.moe.is_act_and_mul, (
                    "The intermediate size required padding, "
                    "but padding is not implemented for gated activations"
                )

                # F.pad takes (left, right) pairs starting from the last dim:
                # here the second-to-last (row) dim of w13 is extended.
                layer.w13_weight = Parameter(
                    torch.nn.functional.pad(
                        w13_weight, (0, 0, 0, intermediate_size_pad)
                    ),
                    requires_grad=False,
                )
                # w2's last dim is the packed intermediate dim (2 fp4 values
                # per byte), so it grows by half the element padding.
                layer.w2_weight = Parameter(
                    torch.nn.functional.pad(
                        layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
                    ),
                    requires_grad=False,
                )
                # w2's block scale has one column per `group_size` elements
                # (see create_weights); use the configured group size rather
                # than a hard-coded 16 so both stay in agreement.
                layer.w2_weight_scale = Parameter(
                    torch.nn.functional.pad(
                        layer.w2_weight_scale,
                        (0, intermediate_size_pad // self.quant_config.group_size),
                    ),
                    requires_grad=False,
                )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
1466

1467
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 fused-MoE quant config from the layer's scales.

        Returns:
            None for the Marlin and FlashInfer TensorRT-LLM paths; otherwise
            the result of ``nvfp4_moe_quant_config`` built from the layer's
            block scales, alphas and inverted input scales.
        """
        if self.use_marlin:
            return None
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        scale_kwargs = {
            "w1_scale": layer.w13_weight_scale,
            "w2_scale": layer.w2_weight_scale,
            "g1_alphas": layer.g1_alphas,
            "g2_alphas": layer.g2_alphas,
            "a1_gscale": layer.w13_input_scale_quant,
            "a2_gscale": layer.w2_input_scale_quant,
        }
        return nvfp4_moe_quant_config(**scale_kwargs)

1485
1486
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass for ``layer``.

        The kernel is chosen in this order:

        1. FlashInfer TensorRT-LLM backend, which receives ``router_logits``
           directly and is called before ``layer.select_experts``.
        2. Marlin, when ``self.use_marlin`` is set.
        3. FlashInfer CUTLASS or CuteDSL backend, when
           ``self.allow_flashinfer`` is set.
        4. ``cutlass_moe_fp4`` as the fallback.

        Args:
            layer: The fused-MoE layer holding the quantized weights, scales
                and routing configuration.
            x: Input hidden states; ``x.shape[0]`` and ``x.shape[1]`` are
                passed as ``m`` and ``k`` to the fallback kernel.
            router_logits: Router output used for expert selection.

        Returns:
            Whatever the selected kernel returns.

        Raises:
            AssertionError: If the MoE is non-gated
                (``self.moe.is_act_and_mul`` is false) and the FlashInfer
                CUTLASS backend is not in use; if FlashInfer is enabled with
                a backend other than CUTLASS/CuteDSL on the non-TRT-LLM
                path; or if ``self.moe_quant_config`` is ``None`` on the
                FlashInfer or fallback CUTLASS paths.
            NotImplementedError: If the TensorRT-LLM backend is selected and
                ``layer.enable_eplb`` is set.
        """
        if not self.moe.is_act_and_mul:
            assert (
                self.allow_flashinfer
                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            ), (
                "Non-gated activations are only supported by the"
                " flashinfer CUTLASS backend for modelopt checkpoints"
            )

        # The TRT-LLM path is handed the raw router logits plus the layer's
        # routing configuration, so it returns before select_experts() runs.
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
                )
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                custom_routing_function=layer.custom_routing_function,
                e_score_correction_bias=layer.e_score_correction_bias,
            )

        # Every remaining backend consumes pre-computed top-k routing results.
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                # NOTE(review): presumably the two bias slots -- confirm
                # against fused_marlin_moe's positional signature.
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
            )

        if self.allow_flashinfer:
            # TENSORRT_LLM returned above, so only these two backends are
            # expected to reach this point.
            assert self.flashinfer_moe_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            # The backend module is imported only for the backend in use.
            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

            assert self.moe_quant_config is not None
            return flashinfer_fn_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
        # only (no EP).
        from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

        assert self.moe_quant_config is not None
        return cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=self.moe_quant_config,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            # TODO: derive from arguments
            m=x.shape[0],
            # NOTE(review): the "* 2" presumably undoes FP4 packing (two
            # values per stored element) -- confirm against the weight layout.
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
        )


# Attach the linear, fused-MoE and KV-cache method classes to the NVFP4
# config as class attributes.
# NOTE(review): presumably bound here, at the end of the module, because the
# method classes are defined after ModelOptNvFp4Config -- the class
# definitions are outside this view; confirm before moving these lines.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod