modelopt.py 58.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from fnmatch import fnmatch
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
13
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
15
from vllm.attention.layer import Attention
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.fused_moe.config import (
18
19
20
21
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
22
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
23
from vllm.model_executor.layers.fused_moe.layer import (
24
25
26
27
28
29
30
31
32
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
33
from vllm.model_executor.layers.quantization import QuantizationMethods
34
from vllm.model_executor.layers.quantization.base_config import (
35
36
37
    QuantizationConfig,
    QuantizeMethodBase,
)
38
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
39
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
40
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
41
42
    flashinfer_trtllm_fp4_moe,
    prepare_static_weights_for_trtllm_fp4_moe,
43
44
45
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
46
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
47
48
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
49
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
50
51
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
52
    is_flashinfer_supporting_global_sf,
53
54
55
56
57
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
58
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
59
60
61
62
63
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
64
from vllm.model_executor.layers.quantization.utils.quant_utils import (
65
66
67
68
69
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
70
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
71
72
73
74
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
75
from vllm.scalar_type import scalar_types
76
77
78
79
80
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
81

82
83
84
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

85
86
logger = init_logger(__name__)

# ``quant_algo`` values accepted by ``ModelOptQuantConfigBase.from_config``;
# anything else is rejected with a ValueError there.
QUANT_ALGOS = [
    "FP8",
    "NVFP4",
]
# KV-cache quantization algorithm names for ModelOpt checkpoints.
KV_CACHE_QUANT_ALGOS = [
    "FP8",
]
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantization method for ModelOpt FP8 checkpoints.

    Loads kv-cache scaling factors from FP8 checkpoints. All behaviour is
    inherited from ``BaseKVCacheMethod``; this subclass exists so the ModelOpt
    config can name a concrete KV-cache method class.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Nothing ModelOpt-specific to set up; defer to the base class.
        super().__init__(quant_config)

class ModelOptQuantConfigBase(QuantizationConfig):
    """Shared base class for ModelOpt quantization configs.

    Concrete subclasses plug in their method implementations through the
    ``LinearMethodCls``, ``FusedMoEMethodCls`` and ``KVCacheMethodCls`` class
    attributes. They implement ``_from_config`` to build an instance from the
    fields that ``from_config`` parses and validates.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # Module-name patterns to leave unquantized: exact names, legacy
        # substrings, or fnmatch-style wildcards (see is_layer_excluded).
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """Check whether the layer at ``prefix`` should be left unquantized.

        Three checks are tried in order:

        1. Exact matching with fused-layer support (``is_layer_skipped``).
        2. Legacy substring matching, also against ``prefix`` with a leading
           ``"language_model."`` removed.
        3. ModelOpt wildcard matching via ``fnmatch``; the ModelOpt
           ``exclude_modules`` list is a list of wildcards.

        Args:
            prefix: Fully qualified module name of the layer.

        Returns:
            True if any check matches, False otherwise (including when there
            are no exclusion patterns).
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support.
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where these cases are
        # handled naturally by the exclude_modules config. It is kept for
        # loading checkpoints generated by older versions: check substring
        # matching for patterns not caught by the exact match above.
        stripped_prefix = (
            prefix.removeprefix("language_model.")
            if prefix.startswith("language_model.")
            else None
        )
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above.
            if exclude_module == prefix:
                continue
            if exclude_module in prefix or (
                stripped_prefix is not None and exclude_module in stripped_prefix
            ):
                return True

        # ModelOpt exclude modules are not simple strings, they are wildcards.
        return any(fnmatch(prefix, pattern) for pattern in self.exclude_modules)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns:
            - ``KVCacheMethodCls(self)`` for Attention layers.
            - ``UnquantizedLinearMethod()`` for excluded or legacy vision
              linear layers, and ``None`` for other excluded or vision layers.
            - ``LinearMethodCls(self)`` or
              ``FusedMoEMethodCls(quant_config=self, layer=layer)`` for
              quantized linear or FusedMoE layers.
            - ``None`` for anything else.
        """
        # Handle kv-cache first so we can focus only on weight quantization
        # thereafter.
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # TODO: The vision_tower / vision_model check is not needed for
        # quantized checkpoints generated by ModelOpt >= 0.39.0, where these
        # layers are handled naturally by the exclude_modules config. It is
        # kept for loading checkpoints generated by older versions.
        legacy_vision = "vision_tower" in prefix or "vision_model" in prefix

        if self.is_layer_excluded(prefix) or legacy_vision:
            # UnquantizedLinearMethod only makes sense for linear layers. For
            # other layer types (e.g. FusedMoE) return None, as the exclusion
            # path always did, so the layer picks its own unquantized default
            # (presumably; confirm per layer type).
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # Now the layer is quantized; handle it here.
        if isinstance(layer, LinearBase):
            return self.LinearMethodCls(self)
        if isinstance(layer, FusedMoE):
            return self.FusedMoEMethodCls(quant_config=self, layer=layer)

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename the exclusion patterns from HF to vLLM module names."""
        if self.exclude_modules:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build a config from validated fields; implemented by subclasses."""
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a ModelOpt quantization config dict.

        Two layouts are accepted:

        - Traditional ModelOpt (``hf_quant_config.json``)::

              {"quantization": {"quant_algo": ..., "kv_cache_quant_algo": ...,
                                "group_size": ..., "exclude_modules": [...]}}

        - Compressed-tensors style (``config.json``)::

              {"quant_algo": ..., "quant_method": "modelopt",
               "kv_cache_quant_algo": ..., "group_size": ..., "ignore": [...]}

        A missing or ``null`` exclusion list is treated as empty.

        Raises:
            ValueError: if ``quantization`` is not a dict, ``quant_algo`` is
                missing or not in ``QUANT_ALGOS``, ``kv_cache_quant_algo`` is
                not a string, the exclusion list is not a list, or
                ``group_size`` cannot be converted to an int.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            group_size_raw = config.get("group_size")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        # A JSON ``null`` means "nothing excluded", same as a missing key.
        if exclude_modules is None:
            exclude_modules = []

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # None means no KV cache quantization.
        if kv_cache_quant_method is not None and not isinstance(
            kv_cache_quant_method, str
        ):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Create an FP8 ModelOpt config.

        Args:
            is_checkpoint_fp8_serialized: Whether checkpoint weights are
                stored as FP8. This selects FP8 parameter dtypes and creates
                the scale parameters.
            kv_cache_quant_method: KV-cache quantization algorithm name, or
                None when the KV cache is not quantized.
            exclude_modules: Module-name patterns to leave unquantized.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on the
        quantization config.

        Returns ``"modelopt"`` only when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` string contains ``"FP8"``.
        ``quant_algo`` may be nested under ``quantization`` or sit at the top
        level. Malformed (non-string / non-dict) values yield ``None`` instead
        of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; guard against non-string
        # values such as an explicit JSON null.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt".
        if quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific nested config structure.
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Compressed-tensors style config with a top-level quant_algo.
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        # group_size (and any other extra fields) arrive via **kwargs and are
        # not used by the FP8 config.
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Supports loading FP8 checkpoints with a static weight scale and a static
    activation (input) scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm
       support.
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static (checkpoint-provided), per-tensor activation quantization.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        Always registers ``weight`` with shape
        ``(sum(output_partition_sizes), input_size_per_partition)``. Its dtype
        is float8_e4m3fn for FP8-serialized checkpoints, otherwise
        ``params_dtype``. For FP8-serialized checkpoints it also registers
        ``weight_scale`` and ``input_scale``, each holding one float32 entry
        per logical output partition.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # NOTE(review): both scales are pre-filled with the most negative
            # float32, presumably so an entry the checkpoint never fills cannot
            # win the max() taken in process_weights_after_loading. Confirm.

            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)

            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the per-partition scales into single per-tensor scales.

        If the loaded weight scales are not all equal, the weight is
        requantized to the maximum scale with ``requantize_with_max_scale``.
        The weight is then stored transposed, and ``input_scale`` is reduced
        to its maximum.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )

        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM on ``x`` using the layer's per-tensor scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    activation scales.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The FusedMoE layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # None means no FlashInfer backend; the generic fused_experts path is
        # used in apply().
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            if (
                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and not self.moe.is_act_and_mul
            ):
                # The TensorRT-LLM (min-latency) path does not handle
                # non-gated MoE; fall back to the CUTLASS backend.
                logger.info_once(
                    "Non-gated MoE is not supported for min-latency mode, "
                    "falling back to high-throughput mode"
                )
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend.

        Returns None for TensorRT-LLM, a FlashInfer CUTLASS prepare/finalize
        for the CUTLASS backend, and otherwise defers to the base class.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the CUTLASS FP8 experts implementation for modular kernels."""
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        ``w13_weight`` has shape ``(num_experts, w13_up_dim, hidden_size)``,
        where ``w13_up_dim`` is ``2 * intermediate_size_per_partition`` for
        gated MoE and ``intermediate_size_per_partition`` otherwise.
        ``w2_weight`` has shape
        ``(num_experts, hidden_size, intermediate_size_per_partition)``.
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stacks w1 and w3 along dim 1; non-gated holds a single one.
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales.
            # NOTE(review): extra_weight_attrs is not read again in this
            # method, so this update has no visible effect here. Confirm
            # whether it is still needed.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Steps:
        - For gated MoE with separate w1/w3 scales, each expert is requantized
          to the per-expert max of the two scales.
        - Input scales are reduced to a single max across experts.
        - When a FlashInfer backend is selected, weights are reordered or
          rotated as that backend requires.
        - Derived scaling factors are registered on the layer.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale; the result
                        # is written back in place into the same row slice.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            if self.moe.is_act_and_mul:
                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        # Called unconditionally: get_fused_moe_quant_config reads the
        # output*_scales* attributes for every non-TensorRT-LLM backend.
        register_moe_scaling_factors(layer)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config (None for the TensorRT-LLM path)."""
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=layer.output2_scales_scalar.squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass on the selected backend.

        - TensorRT-LLM: routing happens inside
          ``apply_flashinfer_per_tensor_scale_fp8``. EPLB is not supported;
          only "silu" and ``renormalize=False`` are accepted.
        - FlashInfer CUTLASS: experts come from ``layer.select_experts``;
          activation must be "silu" or "relu2_no_mul".
        - Otherwise: ``fused_experts`` with ``self.moe_quant_config``.

        Raises:
            NotImplementedError: TensorRT-LLM path with EPLB enabled.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
                )
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
# Bind the FP8 config to its concrete linear / MoE / KV-cache method classes.
# ModelOptQuantConfigBase.get_quant_method instantiates whichever is set here.
(
    ModelOptFp8Config.LinearMethodCls,
    ModelOptFp8Config.FusedMoEMethodCls,
    ModelOptFp8Config.KVCacheMethodCls,
) = (
    ModelOptFp8LinearMethod,
    ModelOptFp8MoEMethod,
    ModelOptFp8KVCacheMethod,
)


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4 (NVFP4).

    Args:
        is_checkpoint_nvfp4_serialized: Whether the checkpoint already stores
            NVFP4-quantized weights. The linear and MoE methods in this file
            reject configs where this is False.
        kv_cache_quant_algo: KV-cache quantization algorithm from the
            checkpoint config, or ``None``.
        exclude_modules: Modules to leave unquantized (forwarded to the base
            class).
        group_size: Number of weight elements that share one block scale.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Always set these attributes. They used to be assigned only for
        # serialized checkpoints, so reading them on any other config raised
        # AttributeError instead of the intended error.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``hf_quant_cfg`` declares
        ``quant_method == "modelopt"`` and an FP4 ``quant_algo``; otherwise
        returns ``None``.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. The `or ""` also covers
        # an explicit null value, which would otherwise crash `.lower()`.
        quant_method = (hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                # Guard against a null / non-string quant_algo, where `in`
                # would raise TypeError.
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed checkpoint fields.

        Raises:
            ValueError: if a serialized NVFP4 checkpoint's ``quantization``
                section is missing ``group_size``, ``kv_cache_quant_algo`` or
                ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    "NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following tensors. Shapes are
    as allocated in ``create_weights``, where N is the output size per
    partition, K is the input size per partition and G is
    ``quant_config.group_size``:

    - input_scale: torch.float32, one scalar per logical output partition.
    - weight: NVFP4, two 4-bit values packed per torch.uint8, shape
      [N, K // 2].
    - weight_scale: FP8-E4M3 per-block scale, shape [N, K // G].
    - weight_scale_2: torch.float32, one scalar per logical output partition.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        # Select the GEMM backend. An explicit VLLM_NVFP4_GEMM_BACKEND wins.
        # Otherwise prefer FlashInfer, then CUTLASS, then Marlin.
        self.backend = "none"
        requested = envs.VLLM_NVFP4_GEMM_BACKEND
        if requested is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif requested.startswith("flashinfer-"):
            self.backend = requested
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weight and its scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized, or if the
                input size per partition is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Weight: 2 fp4 items are packed per uint8 along the input dimension.
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale (one per logical output partition)
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale (one per logical output partition)
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale: one FP8-E4M3 value per `group_size` inputs.
        # Non-serialized checkpoints were rejected above, so the block scales
        # are always FP8-E4M3 here.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-partition scales and lay weights out for the backend."""
        # global scales: reduce the per-partition scalars to a single max.
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the selected NVFP4 backend."""
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        # NOTE(review): assumes x is 2-D (tokens, in_features); TODO confirm.
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend.removeprefix("flashinfer-")
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """Fused-MoE method for ModelOpt NVFP4 checkpoints.

    The expert GEMMs run on a kernel chosen at construction time from
    ``detect_nvfp4_moe_support``: FlashInfer (TensorRT-LLM, CUTLASS or
    CuteDSL backend), Marlin, or vLLM's CUTLASS FP4 MoE.

    Args:
        quant_config: NVFP4 Quant Config.
        layer: The ``FusedMoE`` layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        layer: FusedMoE,
    ) -> None:
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.layer = layer
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        # Only set when FlashInfer is allowed; stays None otherwise.
        self.flashinfer_moe_backend = None
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for ModelOptNvFp4FusedMoE."
            )
        elif self.use_marlin:
            logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
        else:
            logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage for the chosen backend.

        Marlin and FlashInfer TRT-LLM return None. FlashInfer CUTLASS builds
        its own dispatcher. Other backends defer to the base class.
        """
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        experts = select_nvfp4_gemm_impl(
            self.moe,
            self.moe_quant_config,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed expert weights and their scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        group_size = self.quant_config.group_size

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block weight scales: one FP8-E4M3 value per `group_size`
        # input elements.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-expert global weight scales. w13 keeps separate w1 and w3 values
        # (column 0 / 1). They are collapsed in process_weights_after_loading.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # When the FlashInfer backend shares input global scales across all
        # experts, allocate them for every global expert rather than only the
        # local ones.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        global_scale_num_experts = global_num_experts if use_global_sf else num_experts

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Derive GEMM alphas / quant scales and re-layout weights per backend."""
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer and (
            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the NVFP4 MoE quant config, or None for Marlin / TRT-LLM."""
        if (
            self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward on the backend chosen in __init__.

        Raises:
            NotImplementedError: if EPLB is requested with FlashInfer TRT-LLM.
        """
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            if enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
                )
            # The TRT-LLM kernel performs its own routing from router_logits.
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=top_k,
                global_num_experts=global_num_experts,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                custom_routing_function=custom_routing_function,
                e_score_correction_bias=e_score_correction_bias,
            )

        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        elif self.allow_flashinfer:
            assert self.flashinfer_moe_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

            assert self.moe_quant_config is not None
            return flashinfer_fn_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )


# Attach the NVFP4 linear and fused-MoE implementations to the NVFP4 config.
# KV-cache handling reuses the FP8 KV-cache method.
(
    ModelOptNvFp4Config.LinearMethodCls,
    ModelOptNvFp4Config.FusedMoEMethodCls,
    ModelOptNvFp4Config.KVCacheMethodCls,
) = (
    ModelOptNvFp4LinearMethod,
    ModelOptNvFp4FusedMoE,
    ModelOptFp8KVCacheMethod,
)