# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    flashinfer_trtllm_fp4_moe,
    prepare_static_weights_for_trtllm_fp4_moe,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    is_flashinfer_supporting_global_sf,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
84

85
86
87
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

88
89
logger = init_logger(__name__)

90
91
# Values of `quant_algo` accepted by `ModelOptQuantConfigBase.from_config`;
# anything else is rejected there with a ValueError.
QUANT_ALGOS = ["FP8", "NVFP4"]
# NOTE(review): not referenced in the visible part of this file; presumably the
# accepted `kv_cache_quant_algo` values -- confirm against callers.
KV_CACHE_QUANT_ALGOS = ["FP8"]
92
93


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    The class adds no behavior of its own: everything is inherited from
    `BaseKVCacheMethod`.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # The annotation is a string because ModelOptQuantConfigBase is defined
        # further down in this module.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base for the ModelOpt quantization configs.

    Holds the list of modules excluded from quantization, picks the
    quantization method for each layer, and parses the two supported config
    layouts. Subclasses set the three ``*Cls`` attributes and implement
    `_from_config`.
    """

    # Method classes instantiated by `get_quant_method`; subclasses override.
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        """Store the ModelOpt exclusion patterns.

        Args:
            exclude_modules: Module names / wildcard patterns that must stay
                unquantized.
        """
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and ModelOpt wildcard matching.

        The ModelOpt exclude_modules list is a list of wildcards.

        Args:
            prefix: Full module name of the layer being checked.

        Returns:
            True if any exclusion rule matches `prefix`, False otherwise.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match
        language_model_prefix = "language_model."
        has_language_model_prefix = prefix.startswith(language_model_prefix)
        stripped_prefix = prefix.removeprefix(language_model_prefix)
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module == prefix:
                continue
            if exclude_module in prefix or (
                has_language_model_prefix and exclude_module in stripped_prefix
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(fnmatch(prefix, pattern) for pattern in self.exclude_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantization method object for `layer`.

        Args:
            layer: The module being constructed.
            prefix: Full module name of `layer`.

        Returns:
            A KV-cache method for attention layers, an unquantized linear
            method for excluded linear layers, the configured linear / fused
            MoE method for quantized layers, or None when nothing applies.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        # NOTE(review): this returns a linear method even when `layer` is not a
        # LinearBase; behavior kept as-is -- confirm this is intended.
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite the exclusion patterns through `hf_to_vllm_mapper.apply_list`."""
        if self.exclude_modules:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the file names that may hold the ModelOpt quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the concrete config from already-validated fields.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a ModelOpt config dict, then call `_from_config`.

        Two layouts are accepted: the nested one with a top-level
        ``"quantization"`` dict, and the flat one with ``"quant_algo"`` at the
        top level.

        Args:
            config: The raw quantization config.

        Returns:
            Whatever `cls._from_config` builds from the validated fields.

        Raises:
            ValueError: If a field is missing, has the wrong type, or names a
                quantization algorithm outside `QUANT_ALGOS`.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # None means "no KV cache quantization" and is accepted as-is.
        if kv_cache_quant_method is not None and not isinstance(
            kv_cache_quant_method, str
        ):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            # Accept values that convert cleanly (e.g. numeric strings).
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the FP8 settings and warn when an FP8 checkpoint is detected.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint weights are
                already stored in FP8.
            kv_cache_quant_method: KV-cache quantization algorithm, or None.
            exclude_modules: Module names / wildcard patterns left unquantized.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value required (89)."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Args:
            hf_quant_cfg: The checkpoint's quantization config dict, or None.
            user_quant: The user-requested quantization method (not read here).

        Returns:
            "modelopt" when the config declares ``quant_method == "modelopt"``
            and its ``quant_algo`` string contains "FP8"; otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. A missing or non-string
        # value (e.g. an explicit null) simply means "not ours".
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Same string guard for both layouts, so a null `quant_algo` cannot
        # raise TypeError on the `in` test.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the FP8 config from fields validated by `from_config`.

        Extra keyword arguments (such as ``group_size``) are accepted and
        ignored.
        """
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
355

356
357
358
359

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static, per-tensor activation quantization.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, its scales on `layer`.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input width of this partition.
            output_partition_sizes: Output widths of the logical sub-matrices
                held by this layer.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype used when the checkpoint is not FP8.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        is_fp8_checkpoint = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if is_fp8_checkpoint else params_dtype
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if is_fp8_checkpoint:
            # One scale slot per logical sub-matrix, pre-filled with the most
            # negative float32 value.
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the loaded scales to scalars and transpose the weight.

        When the per-shard weight scales differ, the weight is requantized to
        the maximum scale via `requantize_with_max_scale`.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on `x` using the parameters stored on `layer`."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
447
448


449
450
451
452
453
454
455
456
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.
    Args:
        quant_config: The ModelOpt quantization config.
        layer: The fused MoE layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # Stays None unless FlashInfer MoE is both requested and available.
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            if (
                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and not self.moe.is_act_and_mul
            ):
                logger.info_once(
                    "Non-gated MoE is not supported for min-latency mode, "
                    "falling back to high-throughput mode"
                )
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the active backend, if any."""
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the CUTLASS FP8 experts implementation.

        `prepare_finalize` and `layer` are not read here.
        """
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts held by this layer.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size of this
                partition.
            params_dtype: Weight dtype used when the checkpoint is not FP8.
            **extra_weight_attrs: ``weight_loader`` is read; ``quant_method``
                is written for FP8 checkpoints.
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stacks w1 and w3 along dim 1; non-gated MoE has only one.
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )

                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale and write the
                        # result back in place; the returned scale is dropped.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            if self.moe.is_act_and_mul:
                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        register_moe_scaling_factors(layer)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 MoE quant config, or None for the TRT-LLM backend."""
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        # NOTE(review): `a1_gscale` is fed the plain input scale while
        # `a2_gscale` is fed `w2_input_scale_inv`; kept as-is -- confirm the
        # asymmetry is intended.
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=layer.output2_scales_scalar.squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass on the active backend.

        Three paths exist: FlashInfer TRT-LLM (routing is done inside the
        kernel call), FlashInfer CUTLASS, and the default `fused_experts`
        path. The latter two select experts through `layer.select_experts`.

        Parameters not mentioned below are accepted but not read by this
        method.

        Args:
            layer: The fused MoE layer holding weights and scales.
            x: Input hidden states.
            router_logits: Router outputs used for expert selection.
            top_k, num_expert_group, topk_group, e_score_correction_bias:
                Forwarded to the TRT-LLM path only.
            renormalize: Must be False on the TRT-LLM path.
            global_num_experts, expert_map, apply_router_weight_on_input,
            activation: Forwarded to the selected kernel.

        Raises:
            NotImplementedError: On the TRT-LLM path when `layer.enable_eplb`
                is set.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
                )
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
789
790


791
792
793
794
795
796
# Bind the FP8 method classes onto the config after all of them are defined;
# ModelOptQuantConfigBase.get_quant_method instantiates these attributes.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
797
798
799
800
801
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Initialize the ModelOpt NVFP4 quantization config.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint already
                stores NVFP4-quantized weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name, or
                ``None``.
            exclude_modules: Module name patterns forwarded to the base
                config.
            group_size: Block size used for the per-block weight scales.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Set unconditionally: these used to be assigned only inside the
        # ``if`` below, leaving a non-serialized config without the
        # attributes and turning any later read into an AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

817
    def get_name(self) -> QuantizationMethods:
        """Return the name of this quantization method (``"modelopt_fp4"``)."""
        method_name = "modelopt_fp4"
        return method_name
819

820
    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        supported_dtypes = [torch.bfloat16, torch.half, torch.float8_e4m3fn]
        return supported_dtypes

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value this config requires (80)."""
        min_capability = 80
        return min_capability
826

827
828
    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect whether a checkpoint's quantization config selects ModelOpt FP4.

        Args:
            hf_quant_cfg: Quantization config mapping read from the
                checkpoint, or ``None`` when there is none.
            user_quant: Not consulted by this implementation.

        Returns:
            ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"`` and the
            quantization algorithm names an FP4 variant, otherwise ``None``.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. A missing, null or
        # non-string value simply means "not ModelOpt" rather than an error.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific nested structure.
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                # ``quant_algo`` may be null in the JSON; guard before the
                # substring test so it does not raise TypeError.
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Flat, compressed-tensors style config with a top-level
            # quant_algo field.
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

859
    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build a config instance from already-parsed quantization settings.

        Args:
            quant_method: Quantization algorithm string; the checkpoint is
                treated as NVFP4-serialized when it contains ``"NVFP4"``.
            kv_cache_quant_method: KV-cache quantization algorithm, or ``None``.
            exclude_modules: Module name patterns passed to the constructor.
            original_config: The unprocessed config mapping; when it has a
                ``"quantization"`` section, required fields are validated.
            group_size: Block size; ``None`` falls back to 16.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If the checkpoint is NVFP4-serialized and the
                ``"quantization"`` section lacks a required field.
        """
        nvfp4_serialized = "NVFP4" in quant_method
        resolved_group_size = 16 if group_size is None else group_size

        # For FP4, these fields are required in the nested section.
        if nvfp4_serialized and "quantization" in original_config:
            section = original_config["quantization"]
            absent = [
                name
                for name in ("group_size", "kv_cache_quant_algo", "exclude_modules")
                if name not in section
            ]
            if absent:
                raise ValueError(
                    "NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {absent}"
                )

        return cls(
            nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            resolved_group_size,
        )
895
896
897
898
899


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:
900

901
902
903
904
905
906
907
    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,
    Args: quant_config: The ModelOpt quantization config.
    """

908
    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        """Store the config and choose the NVFP4 GEMM backend.

        With ``VLLM_NVFP4_GEMM_BACKEND`` unset, the first available of
        FlashInfer (cutlass), Cutlass and Marlin is chosen. Otherwise the
        requested ``flashinfer-*`` or ``cutlass`` backend is used and its
        availability is asserted.

        Raises:
            ValueError: If no backend could be selected.
        """
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        requested = envs.VLLM_NVFP4_GEMM_BACKEND
        self.backend = "none"
        if requested is None:
            # Auto-detect, in order of preference.
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif requested.startswith("flashinfer-"):
            self.backend = requested
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

935
936
937
938
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 weight and scale parameters on ``layer``.

        Registers, in order: ``weight`` (packed uint8), ``input_scale``,
        ``weight_scale_2`` (one float32 entry per output partition each) and
        ``weight_scale`` (per-block scale).

        Args:
            layer: Module receiving the parameters and size attributes.
            input_size_per_partition: Input features on this partition.
            output_partition_sizes: Output widths of the fused partitions.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Only consulted for the block-scale dtype when the
                checkpoint is not NVFP4-serialized (unreachable after the
                guard below).
            **extra_weight_attrs: Only ``"weight_loader"`` is read.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or the
                input size is not a multiple of 16.
        """
        del input_size, output_size
        cfg = self.quant_config
        if not cfg.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        num_partitions = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Dtype of the per-block weight scale.
        if cfg.is_checkpoint_nvfp4_serialized:
            weight_dtype = torch.float8_e4m3fn
        else:
            weight_dtype = params_dtype

        def _per_tensor_scale():
            # One float32 entry per fused output partition.
            return PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=loader,
            )

        # Weight: two fp4 values are packed per byte along the input dim.
        packed_weight = ModelWeightParameter(
            data=torch.empty(
                out_features,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", packed_weight)

        # Input (activation) scale.
        layer.register_parameter("input_scale", _per_tensor_scale())

        # Global weight scale.
        layer.register_parameter("weight_scale_2", _per_tensor_scale())

        # Per-block weight scale: one entry per ``group_size`` inputs.
        block_scale = ModelWeightParameter(
            data=torch.empty(
                out_features,
                input_size_per_partition // cfg.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight_scale", block_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize loaded NVFP4 parameters for the selected GEMM backend.

        Reduces the global scales to their maxima, derives ``alpha`` and
        ``input_scale_inv``, then lays out the weight and block scale as the
        backend requires (Marlin repack, FlashInfer TRT-LLM shuffle, or
        block-scale swizzle).
        """
        # Global scales: reduce each to a single float32 value.
        reduced_input_scale = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(reduced_input_scale, requires_grad=False)

        reduced_weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(reduced_weight_scale_2, requires_grad=False)

        combined_scale = layer.input_scale * layer.weight_scale_2
        layer.alpha = Parameter(combined_scale, requires_grad=False)

        # Precompute the reciprocal so it is not recomputed at runtime.
        inverse_input_scale = (1 / layer.input_scale).to(torch.float32)
        layer.input_scale_inv = Parameter(inverse_input_scale, requires_grad=False)

        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
            return

        if self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer's nvfp4_quantize would quantize + shuffle, but the
            # weights are already quantized, so only the shuffles are applied.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            tile_m = 128
            raw_weight = layer.weight.data
            raw_scale = layer.weight_scale.data

            shuffled_weight = shuffle_matrix_a(raw_weight.view(torch.uint8), tile_m)
            shuffled_scale = shuffle_matrix_sf_a(raw_scale.view(torch.uint8), tile_m)
            shuffled_scale = shuffled_scale.reshape(raw_scale.shape).view(
                torch.float8_e4m3fn
            )

            layer.weight_scale = Parameter(shuffled_scale, requires_grad=False)
            layer.weight = Parameter(shuffled_weight, requires_grad=False)
            return

        # Remaining backends: swizzle the weight block scale.
        swizzled = swizzle_blockscale(layer.weight_scale)
        layer.weight_scale = Parameter(swizzled, requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)
1061
1062
1063
1064
1065

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear layer on ``x``.

        The Marlin backend is dispatched to ``apply_fp4_marlin_linear``.
        Otherwise ``x`` is quantized with ``scaled_fp4_quant`` and multiplied
        by the FP4 weight through the FlashInfer or Cutlass scaled-FP4 GEMM.

        Args:
            layer: Module holding the processed weight and scales.
            x: Input activations; the output keeps ``x.dtype``.
            bias: Optional bias added to the result.

        Returns:
            Tensor viewed as ``[x.shape[0], layer.weight.shape[0]]`` on the
            non-Marlin paths; the Marlin helper's result otherwise.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        result_dtype = x.dtype
        result_shape = [x.shape[0], layer.weight.shape[0]]

        # Quantize BF16 or FP16 to (FP4 and interleaved block scale).
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # Validate dtypes of the quantized input, input block scale, weight
        # and weight block scale.
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        gemm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            result_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            flashinfer_backend = self.backend.removeprefix("flashinfer-")
            out = flashinfer_scaled_fp4_mm(*gemm_args, backend=flashinfer_backend)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*gemm_args)

        if bias is not None:
            out = out + bias
        return out.view(*result_shape)
1113
1114
1115
1116
1117


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1118
    Args:
1119
1120
1121
        quant_config: NVFP4 Quant Config
    """

1122
1123
1124
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        layer: FusedMoE,
    ) -> None:
        """Store the config and layer, then detect the NVFP4 MoE backend.

        Args:
            quant_config: NVFP4 quantization config.
            layer: Fused-MoE layer; its ``moe_config`` is passed to the base
                class.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.layer = layer

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin
        self.marlin_input_dtype = None
        self.flashinfer_moe_backend = None

        # Log which kernel family will serve this layer.
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for ModelOptNvFp4FusedMoE."
            )
        elif self.use_marlin:
            logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
        else:
            logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
1150

1151
1152
1153
1154
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the active backend, if any.

        Returns ``None`` for Marlin and for FlashInfer TRT-LLM, a FlashInfer
        FP4 Cutlass prepare/finalize object for FlashInfer Cutlass, and the
        base-class result otherwise.
        """
        backend = self.flashinfer_moe_backend
        flashinfer_trtllm = (
            self.allow_flashinfer and backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if self.use_marlin or flashinfer_trtllm:
            return None

        if self.allow_flashinfer and backend == FlashinferMoeBackend.CUTLASS:
            # For now, fp4 moe only works with the flashinfer dispatcher.
            dispatcher = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
            logger.debug_once("%s", dispatcher.__class__.__name__)
            return dispatcher

        return super().maybe_make_prepare_finalize(routing_tables)
1172

1173
1174
1175
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Select the NVFP4 experts implementation for the modular kernel.

        ``prepare_finalize`` and ``layer`` are not consulted here; the choice
        is delegated to ``select_nvfp4_gemm_impl`` using ``self.moe`` and the
        MoE quant config, which must already be set.
        """
        quant_cfg = self.moe_quant_config
        assert quant_cfg is not None
        chosen = select_nvfp4_gemm_impl(
            self.moe,
            quant_cfg,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", chosen.__class__.__name__)
        return chosen
1186

1187
1188
1189
1190
1191
1192
    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales are named ``weight_scale_2``.

        FP4 variants use the 'weight_scale_2' pattern, so this always
        returns ``True``.
        """
        uses_pattern = True
        return uses_pattern

1193
1194
1195
1196
1197
1198
1199
1200
1201
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 expert weights and scales on ``layer``.

        Registers, in order: ``w13_weight``, ``w2_weight`` (packed uint8),
        ``w13_weight_scale``, ``w2_weight_scale`` (FP8-E4M3 block scales),
        ``w13_weight_scale_2``, ``w2_weight_scale_2`` (float32 per-tensor
        scales) and ``w13_input_scale``, ``w2_input_scale`` (float32).

        Args:
            layer: Module receiving the parameters.
            num_experts: Number of experts held by this layer.
            hidden_size: Hidden dimension.
            intermediate_size_per_partition: Intermediate dimension on this
                partition.
            params_dtype: Recorded on ``layer.params_dtype``.
            **extra_weight_attrs: ``"weight_loader"`` and
                ``"global_num_experts"`` are read.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        packed_dtype = torch.uint8
        block_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        group_size = self.quant_config.group_size

        def _expert_matrix(rows: int, cols: int, dtype: torch.dtype):
            # Per-expert 2-D parameter of shape (num_experts, rows, cols).
            return ModelWeightParameter(
                data=torch.empty(num_experts, rows, cols, dtype=dtype),
                input_dim=1,
                output_dim=2,
                weight_loader=weight_loader,
            )

        def _float32_scale(*shape):
            # float32 per-tensor scale parameter of the given shape.
            return PerTensorScaleParameter(
                data=torch.empty(*shape, dtype=torch.float32),
                weight_loader=weight_loader,
            )

        # GEMM 1 weight: two fp4 values are packed per byte along hidden_size.
        layer.register_parameter(
            "w13_weight",
            _expert_matrix(
                2 * intermediate_size_per_partition, hidden_size // 2, packed_dtype
            ),
        )

        # GEMM 2 weight: two fp4 values are packed per byte along the
        # intermediate dimension.
        layer.register_parameter(
            "w2_weight",
            _expert_matrix(
                hidden_size, intermediate_size_per_partition // 2, packed_dtype
            ),
        )

        # Block scales: one entry per ``group_size`` input elements.
        layer.register_parameter(
            "w13_weight_scale",
            _expert_matrix(
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                block_scale_dtype,
            ),
        )
        layer.register_parameter(
            "w2_weight_scale",
            _expert_matrix(
                hidden_size,
                intermediate_size_per_partition // group_size,
                block_scale_dtype,
            ),
        )

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        # Per-tensor weight scales; w13 keeps one value each for w1 and w3.
        layer.register_parameter("w13_weight_scale_2", _float32_scale(num_experts, 2))
        layer.register_parameter("w2_weight_scale_2", _float32_scale(num_experts))

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Input scales are sized by the global expert count when the
        # FlashInfer backend supports global scale factors.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        input_scale_experts = global_num_experts if use_global_sf else num_experts

        layer.register_parameter(
            "w13_input_scale", _float32_scale(input_scale_experts, 2)
        )
        layer.register_parameter("w2_input_scale", _float32_scale(input_scale_experts))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Finalize loaded NVFP4 expert parameters for the active backend.

        Optionally reorders w13 for FlashInfer, collapses the w13 per-tensor
        weight scale, derives ``g1_alphas`` / ``g2_alphas`` and the inverted
        input scales, then applies the backend-specific layout: TRT-LLM
        shuffling, Marlin repacking, or block-scale swizzling.
        """
        fi_backend = self.flashinfer_moe_backend
        on_trtllm = (
            self.allow_flashinfer and fi_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        reorder_w13 = self.allow_flashinfer and fi_backend in (
            FlashinferMoeBackend.CUTLASS,
            FlashinferMoeBackend.TENSORRT_LLM,
        )

        # GEMM 1 processing
        w13 = layer.w13_weight.data
        w13_scale = layer.w13_weight_scale.data
        if reorder_w13:
            w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale, dim=-2)
        layer.w13_weight = Parameter(w13, requires_grad=False)
        layer.w13_weight_scale = Parameter(w13_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2: only column 0 is kept,
        # so warn when the w1 and w3 columns disagree.
        scales_agree = torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        )
        if not scales_agree:
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            fi_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales
            # are shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
            w2_input_scale = layer.w2_input_scale

        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )
        # Used for quantization, hence the reciprocal.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )
        # Used for quantization, hence the reciprocal.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        if on_trtllm:
            # Prepare static weights for the TRT-LLM kernel.
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            shuffled = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = shuffled
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            layer.w13_weight_scale = Parameter(
                swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
            )
            layer.w2_weight_scale = Parameter(
                swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
1445

1446
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 MoE quant config from ``layer``'s processed scales.

        Returns ``None`` for the Marlin and FlashInfer TRT-LLM backends.
        """
        skips_quant_config = (
            self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if skips_quant_config:
            return None

        scale_kwargs = {
            "w1_scale": layer.w13_weight_scale,
            "w2_scale": layer.w2_weight_scale,
            "g1_alphas": layer.g1_alphas,
            "g2_alphas": layer.g2_alphas,
            "a1_gscale": layer.w13_input_scale_quant,
            "a2_gscale": layer.w2_input_scale_quant,
        }
        return nvfp4_moe_quant_config(**scale_kwargs)

1464
1465
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass on the active backend.

        Dispatch order: FlashInfer TRT-LLM (which receives the router logits
        directly), then, after ``layer.select_experts``, Marlin, FlashInfer
        Cutlass/CuteDSL, and finally the plain Cutlass FP4 kernel.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set on the FlashInfer
                TRT-LLM path.
        """
        assert activation == "silu", "Only SiLU activation is supported."

        fi_backend = self.flashinfer_moe_backend

        if self.allow_flashinfer and fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
                )
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=top_k,
                global_num_experts=global_num_experts,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                custom_routing_function=custom_routing_function,
                e_score_correction_bias=e_score_correction_bias,
            )

        # Expert selection for every remaining backend.
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                input_dtype=self.marlin_input_dtype,
            )

        if self.allow_flashinfer:
            assert fi_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            if fi_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4,
                )

                flashinfer_kernel = flashinfer_cutlass_moe_fp4
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4,
                )

                flashinfer_kernel = flashinfer_cutedsl_moe_fp4

            assert self.moe_quant_config is not None
            return flashinfer_kernel(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
        # only (no EP).
        from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

        assert self.moe_quant_config is not None
        # TODO: derive the problem sizes from arguments
        return cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=self.moe_quant_config,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            m=x.shape[0],
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
        )
1588
1589
1590
1591
1592


# Attach the NVFP4 linear and fused-MoE method classes to the NVFP4 config
# class as class attributes. The KV-cache method is the FP8 one
# (ModelOptFp8KVCacheMethod), shared with the FP8 config.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod