modelopt.py 62.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
18
19
20
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
22
from vllm.model_executor.layers.fused_moe.layer import (
23
24
25
26
27
28
29
30
31
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
32
from vllm.model_executor.layers.quantization import QuantizationMethods
33
from vllm.model_executor.layers.quantization.base_config import (
34
35
36
    QuantizationConfig,
    QuantizeMethodBase,
)
37
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
38
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
39
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
40
    flashinfer_trtllm_fp4_moe,
41
    flashinfer_trtllm_fp4_routed_moe,
42
    prepare_static_weights_for_trtllm_fp4_moe,
43
44
45
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
46
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
47
48
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
49
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
50
51
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
52
    is_flashinfer_supporting_global_sf,
53
54
55
56
57
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
58
59
60
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
61
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
62
63
64
65
66
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
67
from vllm.model_executor.layers.quantization.utils.quant_utils import (
68
69
70
71
72
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
73
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
74
75
76
77
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
78
from vllm.scalar_type import scalar_types
79
80
81
82
83
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
84
from vllm.utils.math_utils import round_up
85

86
87
88
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

# Weight quantization algorithm names ("quant_algo") accepted by
# ModelOptQuantConfigBase.from_config; anything else is rejected there.
QUANT_ALGOS = ["FP8", "NVFP4"]
# KV-cache quantization algorithm names ("kv_cache_quant_algo").
# NOTE(review): not referenced in this part of the module — confirm where it
# is consumed.
KV_CACHE_QUANT_ALGOS = ["FP8"]


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method used by ModelOpt FP8 checkpoints.

    Loading the kv-cache scaling factors from the checkpoint is handled
    entirely by ``BaseKVCacheMethod``; this subclass only exists so the
    ModelOpt config can name a dedicated KV-cache method class.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # No ModelOpt-specific state; the base class does all the set-up.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Shared configuration logic for ModelOpt-quantized checkpoints.

    Subclasses plug in their concrete implementations through the
    ``LinearMethodCls``, ``FusedMoEMethodCls`` and ``KVCacheMethodCls`` class
    attributes, and implement ``_from_config`` to build themselves from the
    fields parsed by ``from_config``.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # Module names / fnmatch wildcards whose layers stay unquantized.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Three checks are applied in order:

        1. exact matching with fused-layer support (``is_layer_skipped``);
        2. legacy substring matching, also trying ``prefix`` with a leading
           ``"language_model."`` stripped;
        3. ``fnmatch`` wildcard matching, because ModelOpt ``exclude_modules``
           entries are wildcards.

        Args:
            prefix: Fully-qualified module name of the layer.

        Returns:
            True if any check matches, False otherwise.
        """
        if not self.exclude_modules:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard-coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where such layers are
        # handled naturally by the exclude_modules config. It is kept so that
        # checkpoints generated by older versions still load: check substring
        # matching for patterns not caught by the exact match above.
        unprefixed = (
            prefix.removeprefix("language_model.")
            if prefix.startswith("language_model.")
            else None
        )
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module == prefix:
                continue
            if exclude_module in prefix or (
                unprefixed is not None and exclude_module in unprefixed
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        return any(
            fnmatch(prefix, wildcard_pattern)
            for wildcard_pattern in self.exclude_modules
        )

    @staticmethod
    def _get_unquantized_method(
        layer: torch.nn.Module,
    ) -> Optional["QuantizeMethodBase"]:
        """Return the method for a layer that must stay unquantized.

        Linear layers get ``UnquantizedLinearMethod``; every other layer type
        gets ``None``.
        """
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Fully-qualified module name, used for exclusion matching
                and for choosing the Marlin input dtype.

        Returns:
            ``KVCacheMethodCls`` for attention layers; an unquantized method
            (or None for non-linear layers) for excluded and legacy vision
            layers; ``LinearMethodCls`` / ``FusedMoEMethodCls`` for linear and
            fused-MoE layers; None for anything else.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            return self._get_unquantized_method(layer)

        # TODO: This special hard-coded logic is not needed for quantized
        # checkpoints generated by ModelOpt >= 0.39.0, where vision layers are
        # handled naturally by the exclude_modules config. It is kept so that
        # checkpoints generated by older versions still load. Non-linear
        # layers get None here too, exactly like the exclusion branch above.
        if "vision_tower" in prefix or "vision_model" in prefix:
            return self._get_unquantized_method(layer)

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
        else:
            return None

        # Methods running on the Marlin backend get a per-layer input dtype.
        if getattr(quant_method, "backend", "") == "marlin":
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
        return quant_method

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` names with ``hf_to_vllm_mapper``."""
        if self.exclude_modules:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        """Build the subclass config from fields validated by ``from_config``."""
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse a ModelOpt quantization config and build the subclass config.

        Two layouts are accepted:

        * legacy ``hf_quant_config.json``:
          ``{"quantization": {"quant_algo": ..., "exclude_modules": [...]}}``
        * compressed-tensors style ``config.json``:
          ``{"quant_algo": ..., "quant_method": "modelopt", "ignore": [...]}``

        A missing or null exclusion list is treated as empty.

        Raises:
            ValueError: if ``quantization`` is not a dict, ``quant_algo`` is
                missing or not in ``QUANT_ALGOS``, ``kv_cache_quant_algo`` is
                not a string, the exclusion list is not a list, or
                ``group_size`` cannot be converted to an int.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            fields = quant_config
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_key = "exclude_modules"
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            fields = config
            # "ignore" is the key in config.json
            exclude_key = "ignore"

        quant_method = fields.get("quant_algo")
        kv_cache_quant_method = fields.get("kv_cache_quant_algo")
        group_size_raw = fields.get("group_size")
        # An explicit null is treated the same as a missing key.
        exclude_modules = fields.get(exclude_key)
        if exclude_modules is None:
            exclude_modules = []

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # None means no KV cache quantization; anything else must be a string.
        if kv_cache_quant_method is not None and not isinstance(
            kv_cache_quant_method, str
        ):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Store the FP8 config.

        Args:
            is_checkpoint_fp8_serialized: Whether checkpoint weights are
                already FP8; controls weight dtype and scale allocation in
                the linear / MoE methods.
            kv_cache_quant_method: KV-cache quantization algorithm, if any.
            exclude_modules: Module names / wildcards to leave unquantized.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the quantization method name, ``"modelopt"``."""
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the supported activation dtypes (bf16 and fp16)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability, encoded as an int (89)."""
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` string contains ``"FP8"``,
        either nested under ``"quantization"`` or at the top level.
        Malformed values (null / non-string) yield None instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. Guard the type so a null
        # value does not crash on .lower().
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            # Only proceed if the method is explicitly "modelopt"
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific layout: algorithm nested under "quantization".
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Compressed-tensors style config with a top-level quant_algo.
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Both layouts get the same type check before the substring test.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        """Build the config; extra keyword fields (e.g. group_size) are ignored."""
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations are quantized with a static, per-tensor scale
        # (the loaded ``input_scale``).
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, its scales on ``layer``.

        ``weight`` has shape (sum(output_partition_sizes),
        input_size_per_partition) and dtype ``float8_e4m3fn`` when the
        checkpoint is FP8-serialized (``params_dtype`` otherwise). For
        FP8-serialized checkpoints, float32 ``weight_scale`` and
        ``input_scale`` parameters are registered with one entry per element
        of ``output_partition_sizes``, initialised to the float32 minimum.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            num_partitions = len(output_partition_sizes)
            # WEIGHT SCALE, then INPUT SCALE: identical per-partition layout.
            for scale_name in ("weight_scale", "input_scale"):
                scale = PerTensorScaleParameter(
                    data=torch.empty(num_partitions, dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter(scale_name, scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse the per-partition scales into single per-tensor scales.

        When the loaded weight scales are not all equal, the weight is
        requantized to their maximum via ``requantize_with_max_scale``. The
        weight is then stored transposed, and ``input_scale`` is reduced to
        its maximum.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        # Only per-tensor quantization is supported, so differing
        # per-partition scales must be unified first.
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM ``x @ weight`` (plus optional ``bias``)."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Depending on configuration it runs FlashInfer
    TensorRT-LLM kernels, FlashInfer CUTLASS kernels, or ``fused_experts``.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The ``FusedMoE`` layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # FlashInfer is opt-in via VLLM_USE_FLASHINFER_MOE_FP8; None selects
        # the fused_experts path in apply().
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            if (
                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and not self.moe.is_act_and_mul
            ):
                logger.info_once(
                    "Non-gated MoE is not supported for min-latency mode, "
                    "falling back to high-throughput mode"
                )
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Build the prepare/finalize stage for the selected backend.

        Returns None for TensorRT-LLM, a FlashInfer CUTLASS prepare/finalize
        for the CUTLASS backend, and defers to the base class otherwise.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation via ``select_cutlass_fp8_gemm_impl``."""
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        Registered parameters:

        * ``w13_weight``: (num_experts, w13_dim, hidden_size), where w13_dim is
          ``2 * intermediate_size_per_partition`` for gated MoE and
          ``intermediate_size_per_partition`` otherwise.
        * ``w2_weight``: (num_experts, hidden_size,
          intermediate_size_per_partition).
        * FP8-serialized checkpoints only, all float32 and filled with 1.0:
          ``w13_weight_scale`` (num_experts, 2) for gated MoE or
          (num_experts, 1) otherwise; ``w2_weight_scale``,
          ``w13_input_scale`` and ``w2_input_scale`` of shape (num_experts,).
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Gated MoE stacks w1 and w3 into one fused w13 projection.
        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            # NOTE(review): extra_weight_attrs is not read again in this method
            # after this update — confirm whether it is still needed.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        In order, this:

        1. pads the intermediate dim when a FlashInfer backend is active;
        2. folds separate w1/w3 weight scales into one scale per expert,
           requantizing w13;
        3. reduces the input scales to a single scalar each;
        4. reorders w13 and, for TensorRT-LLM, rotates the weights;
        5. registers the derived MoE scaling factors.
        """
        if self.flashinfer_moe_backend is not None:
            self._maybe_pad_intermediate_for_flashinfer(layer)

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale.
                # w13_weight is (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3.
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    for shard_id in range(2):  # w1 and w3
                        rows = slice(
                            shard_id * intermediate_size,
                            (shard_id + 1) * intermediate_size,
                        )
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][rows, :],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize in place using the combined max scale
                        (
                            layer.w13_weight[expert_id][rows, :],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )

        # Input scales must be equal for each expert in fp8 MoE layers, so
        # each is reduced to its maximum over experts.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            if self.moe.is_act_and_mul:
                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        register_moe_scaling_factors(layer)

    def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
        """Pad intermediate size so FlashInfer kernels' alignment constraints hold.

        Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
        used for GEMM to be divisible by a small alignment value. When this is
        not satisfied (e.g. with certain tensor-parallel sizes), we zero-pad
        the gate/up and down projection weights along the intermediate dim.

        For gated MoE, ``w13_weight`` stacks the w1 and w3 shards along dim 1.
        Each shard is padded on its own so that the w1/w3 boundary moves to
        ``padded_intermediate``. This is what the half-split of w13 in
        ``process_weights_after_loading`` assumes. Padding only at the end
        would leave the boundary at ``intermediate``, so each "half" would
        mix w1 and w3 rows.
        """
        if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
            return

        # Current local intermediate size (per partition) is the K dimension of
        # the down projection.
        num_experts, hidden_size, intermediate = layer.w2_weight.shape

        min_alignment = 16
        padded_intermediate = round_up(intermediate, min_alignment)

        if padded_intermediate == intermediate:
            return

        logger.info(
            "Padding intermediate size from %d to %d for up/down projection weights.",
            intermediate,
            padded_intermediate,
        )

        num_shards = 2 if self.moe.is_act_and_mul else 1

        # Pad w13 along its intermediate dimension, one shard at a time.
        w13 = layer.w13_weight.data
        assert w13.shape[1] == num_shards * intermediate, (
            f"Expected w13 intermediate dim {num_shards * intermediate}, "
            f"got {w13.shape[1]}"
        )
        padded_w13 = w13.new_zeros(
            (num_experts, num_shards * padded_intermediate, hidden_size)
        )
        for shard_id in range(num_shards):
            src = shard_id * intermediate
            dst = shard_id * padded_intermediate
            padded_w13[:, dst : dst + intermediate, :] = w13[
                :, src : src + intermediate, :
            ]
        layer.w13_weight.data = padded_w13

        # Pad w2 along its intermediate (last) dimension.
        w2 = layer.w2_weight.data
        padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
        padded_w2[:, :, :intermediate] = w2
        layer.w2_weight.data = padded_w2

        if hasattr(layer, "intermediate_size_per_partition"):
            layer.intermediate_size_per_partition = padded_intermediate

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 w8a8 MoE quant config from the processed scales.

        Returns None for the TensorRT-LLM backend, whose ``apply`` path does
        not read ``moe_quant_config``. The ``output*_scales*`` and
        ``w2_input_scale_inv`` attributes are presumably registered by
        ``register_moe_scaling_factors`` — confirm there.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=layer.output2_scales_scalar.squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass on the selected backend.

        The TensorRT-LLM path passes ``router_logits`` and the routing
        parameters straight to the FlashInfer kernel. The CUTLASS and
        ``fused_experts`` paths first select experts with
        ``layer.select_experts``.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
                )
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )
            assert not layer.renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

        assert self.moe_quant_config is not None

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=layer.activation,
            quant_config=self.moe_quant_config,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )


# Wire the FP8 method implementations into the config class. This is done
# here, after those classes are defined, because they are declared after
# ModelOptFp8Config; get_quant_method instantiates them via these attributes.
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4 (NVFP4) checkpoints."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the NVFP4 config.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint already
                stores NVFP4-quantized weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name, if any.
            exclude_modules: Module name patterns excluded from quantization.
            group_size: Number of input elements sharing one block scale.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Always set these: previously they were only assigned for serialized
        # checkpoints, leaving the attributes missing otherwise.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

    def get_name(self) -> QuantizationMethods:
        """Return the registered quantization method name."""
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (80) for this config."""
        return 80

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``hf_quant_cfg["quant_method"]`` is
        ``"modelopt"`` (case-insensitive) and either the nested
        ``quantization.quant_algo`` contains ``"NVFP4"`` or, without a nested
        section, the top-level ``quant_algo`` contains ``"FP4"``
        (case-insensitive). Returns ``None`` otherwise, including when these
        fields are missing or are not strings.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "")

        # Only proceed if the method is explicitly "modelopt"; non-string
        # values (e.g. None) are treated as "not ModelOpt" instead of raising.
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from already-parsed ModelOpt fields.

        Raises:
            ValueError: if an NVFP4 checkpoint has a nested ``quantization``
                section that lacks any of ``group_size``,
                ``kv_cache_quant_algo`` or ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    "NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters, where
    N is the output size per partition, K the input size per partition and
    G is ``quant_config.group_size``:

    - ``input_scale``: torch.float32, one per-tensor scale per logical output
      partition (reduced to a scalar after loading).
    - ``weight``: NVFP4 packed two values per byte (torch.uint8), shape
      ``[N, K // 2]``.
    - ``weight_scale``: FP8-E4M3 per-block scales, shape ``[N, K // G]``.
    - ``weight_scale_2``: torch.float32 global weight scale, one per logical
      output partition (reduced to a scalar after loading).

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        """Select the NVFP4 GEMM backend.

        ``VLLM_NVFP4_GEMM_BACKEND`` is honoured when set; otherwise
        FlashInfer CUTLASS, then CUTLASS, then Marlin are tried in order.

        Raises:
            ValueError: if no usable backend was found.
        """
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        requested_backend = envs.VLLM_NVFP4_GEMM_BACKEND
        self.backend = "none"
        if requested_backend is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif requested_backend.startswith("flashinfer-"):
            self.backend = requested_backend
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested_backend == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 weight and scale parameters on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is not supported) or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Weight: two FP4 values are packed per uint8 along the input dim.
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input (activation) scale, one entry per logical output partition.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global weight scale, one entry per logical output partition.
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per-block weight scale: one FP8-E4M3 scale per `group_size` inputs.
        # (The serialized check above guarantees an FP8 scale dtype.)
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize loaded NVFP4 parameters for the selected GEMM backend.

        This method:

        - Reduces ``input_scale`` and ``weight_scale_2`` to scalars by taking
          the max over partitions.
        - Precomputes ``alpha`` and ``input_scale_inv``.
        - Converts ``weight`` / ``weight_scale`` to the backend's layout.
        """
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        # Combined scale factor passed to the FP4 GEMM in `apply`.
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            # The Marlin path in `apply` uses weight_scale_2, not alpha.
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the NVFP4-quantized linear layer to ``x``.

        For the non-Marlin backends, ``x`` may carry any number of leading
        dimensions. It is flattened to 2-D for the GEMM, and the output is
        reshaped to ``(*x.shape[:-1], out_features)`` with the dtype of ``x``.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        output_dtype = x.dtype
        # Preserve all leading dims (previously only x.shape[0] was kept,
        # which broke the final `view` for inputs with more than 2 dims).
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
        x_2d = x.reshape(-1, x.shape[-1])

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x_2d, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1150
    Args:
1151
1152
1153
        quant_config: NVFP4 Quant Config
    """

1154
1155
1156
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        layer: FusedMoE,
    ) -> None:
        """Record the NVFP4 config and choose the MoE kernel family.

        Args:
            quant_config: Parsed ModelOpt NVFP4 quantization config.
            layer: The FusedMoE layer served by this method.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin
        self.marlin_input_dtype = None
        self.flashinfer_moe_backend = None

        # Without FlashInfer, Marlin wins over the plain Cutlass kernels.
        if not self.allow_flashinfer:
            if self.use_marlin:
                logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
            else:
                logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
            return

        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )

1183
1184
1185
1186
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object matching the selected backend.

        The three outcomes are:

        - ``None`` for Marlin and FlashInfer TRT-LLM.
        - The FlashInfer FP4 dispatcher for FlashInfer CUTLASS.
        - Otherwise, whatever the base class builds.
        """
        active_backend = self.flashinfer_moe_backend if self.allow_flashinfer else None

        if self.use_marlin or active_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if active_backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize(routing_tables)

        # For now, fp4 moe only works with the flashinfer dispatcher.
        dispatcher = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
        logger.debug_once("%s", dispatcher.__class__.__name__)
        return dispatcher

1205
1206
1207
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the NVFP4 experts implementation for the modular kernel.

        ``prepare_finalize`` and ``layer`` are part of the interface but are
        not consulted here.
        """
        assert self.moe_quant_config is not None
        chosen = select_nvfp4_gemm_impl(
            self.moe, self.moe_quant_config, allow_flashinfer=self.allow_flashinfer
        )
        logger.debug_once("Using %s", chosen.__class__.__name__)
        return chosen

1219
1220
1221
1222
1223
1224
    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales are named ``weight_scale_2``.

        All FP4 variants follow this naming, so the answer is always True.
        """
        return True

1225
1226
1227
1228
1229
1230
1231
1232
1233
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 expert weights and scales on ``layer``.

        Shapes below use E = ``num_experts``, H = ``hidden_size``,
        I = ``intermediate_size_per_partition``, G = ``group_size``, and
        S = 2 when ``self.moe.is_act_and_mul`` (else 1):

        - ``w13_weight``: uint8 ``[E, S*I, H // 2]`` (two FP4 values per byte)
        - ``w2_weight``: uint8 ``[E, H, I // 2]``
        - ``w13_weight_scale``: FP8-E4M3 block scales ``[E, S*I, H // G]``
        - ``w2_weight_scale``: FP8-E4M3 block scales ``[E, H, I // G]``
        - ``w13_weight_scale_2``: float32 ``[E, S]``
        - ``w2_weight_scale_2``: float32 ``[E]``
        - ``w13_input_scale``: float32 ``[E', S]``
        - ``w2_input_scale``: float32 ``[E']``

        Here E' is ``extra_weight_attrs["global_num_experts"]`` when the
        FlashInfer backend supports global scale factors, else E.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is not supported).
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        group_size = self.quant_config.group_size
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        # Gated (act-and-mul) MoEs stack w1 and w3 along dim 1 of w13.
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block scales: one FP8 scale per `group_size` input elements.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-tensor (global) weight scales. The former
        # `extra_weight_attrs.update({"quant_method": ...})` calls were removed:
        # `extra_weight_attrs` is a fresh kwargs dict and was never read again.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # Input (activation) scales. When the FlashInfer backend supports
        # global scale factors they are allocated for all global experts.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        global_scale_num_experts = global_num_experts if use_global_sf else num_experts

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_scale_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded NVFP4 expert weights for the selected backend.

        Steps, in order:

        1. For FlashInfer CUTLASS / TRT-LLM with gated activations, reorder
           the stacked w1/w3 halves of ``w13`` (weights and block scales)
           into w3/w1 order.
        2. Keep one ``w13_weight_scale_2`` per expert (column 0), warning if
           the w1 and w3 scales differ.
        3. Derive ``g1_alphas`` / ``g2_alphas`` (input scale times global
           weight scale) and the inverted input scales used to quantize
           activations.
        4. Backend-specific layout work:

           - TRT-LLM: weight/scale shuffling.
           - Marlin: repacking.
           - Otherwise: block-scale swizzling, plus padding of the
             intermediate dimension when swizzling padded it.
        """
        use_flashinfer_cutlass = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        )
        use_flashinfer_trtllm = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )

        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data
        reorder_w13 = (
            use_flashinfer_cutlass or use_flashinfer_trtllm
        ) and self.moe.is_act_and_mul
        if reorder_w13:
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Only one global weight scale per expert is kept for w13 (column 0),
        # so mismatching w1/w3 scales lose precision.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        if use_flashinfer_trtllm:
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            # Swizzling may pad dim 1 of the w13 scales; if so, pad the
            # weights and w2 scales to the same intermediate size.
            w13_weight = layer.w13_weight
            intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
            if intermediate_size_pad:
                # padding gated activations will require to split w1 and w3
                # and pad them individually
                assert not self.moe.is_act_and_mul, (
                    "The intermediate size required padding, "
                    "but padding is not implemented for gated activations"
                )
                group_size = self.quant_config.group_size

                layer.w13_weight = Parameter(
                    torch.nn.functional.pad(
                        w13_weight, (0, 0, 0, intermediate_size_pad)
                    ),
                    requires_grad=False,
                )
                # w2 packs two FP4 values per byte along the intermediate dim.
                layer.w2_weight = Parameter(
                    torch.nn.functional.pad(
                        layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
                    ),
                    requires_grad=False,
                )
                # One block scale per `group_size` intermediate elements
                # (was hard-coded to 16; `create_weights` uses group_size).
                layer.w2_weight_scale = Parameter(
                    torch.nn.functional.pad(
                        layer.w2_weight_scale, (0, intermediate_size_pad // group_size)
                    ),
                    requires_grad=False,
                )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )

1516
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 quant config consumed by the modular MoE kernels.

        Marlin and FlashInfer TRT-LLM are dispatched directly in ``apply``
        and get ``None`` here.
        """
        dispatched_directly = (
            self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if dispatched_directly:
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

1534
1535
1536
1537
    @property
    def supports_eplb(self) -> bool:
        """Whether expert-parallel load balancing is supported (always True)."""
        return True

1538
1539
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass on the configured backend.

        Dispatch order:

        1. FlashInfer TRT-LLM with EPLB disabled: the raw ``router_logits`` and
           the layer's routing parameters are forwarded to
           ``flashinfer_trtllm_fp4_moe``; ``layer.select_experts`` is not
           called on this path.
        2. Otherwise top-k weights/ids are computed with
           ``layer.select_experts`` and the call goes to the first matching
           backend:
           - FlashInfer TRT-LLM routed kernel (reached only when EPLB is on);
           - Marlin (``self.use_marlin``);
           - FlashInfer CUTLASS or CuteDSL (``self.allow_flashinfer``);
           - ``cutlass_moe_fp4`` as the final fallback.

        Args:
            layer: The fused-MoE layer providing the quantized expert weights,
                their scales, and the routing configuration.
            x: Input hidden states. The CUTLASS fallback passes ``x.shape[0]``
                and ``x.shape[1]`` as its ``m`` and ``k`` sizes.
            router_logits: Router logits used for expert selection.

        Returns:
            Whatever the selected backend kernel returns.

        Raises:
            AssertionError: If a non-gated activation is configured
                (``self.moe.is_act_and_mul`` is false) while the FlashInfer
                CUTLASS backend is not active, if the FlashInfer branch sees a
                backend other than CUTLASS/CuteDSL, or if
                ``self.moe_quant_config`` is missing where it is required.
        """
        if not self.moe.is_act_and_mul:
            # Per the message below, non-gated activations are only handled by
            # the FlashInfer CUTLASS backend for modelopt checkpoints.
            assert (
                self.allow_flashinfer
                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            ), (
                "Non-gated activations are only supported by the"
                " flashinfer CUTLASS backend for modelopt checkpoints"
            )

        # Evaluated once; the TRT-LLM backend has two entry points below.
        use_trtllm = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )

        # TRT-LLM without EPLB: hand the kernel the raw logits and routing
        # parameters, so no expert selection happens in Python here.
        if use_trtllm and not layer.enable_eplb:
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                custom_routing_function=layer.custom_routing_function,
                e_score_correction_bias=layer.e_score_correction_bias,
            )

        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        # EPLB path: reaching here with TRT-LLM means EPLB is enabled, so use
        # the variant that consumes the precomputed top-k ids/weights.
        if use_trtllm:
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                global_num_experts=layer.global_num_experts,
            )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
            )

        if self.allow_flashinfer:
            assert self.flashinfer_moe_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            # Import lazily so only the selected FlashInfer kernel is loaded.
            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

            assert self.moe_quant_config is not None
            return flashinfer_fn_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
        # only (no EP).
        from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

        assert self.moe_quant_config is not None
        return cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=self.moe_quant_config,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            # TODO: derive from arguments
            m=x.shape[0],
            # NOTE(review): the "* 2" presumably unpacks two FP4 values per
            # stored element of w2's last dim -- confirm against the kernel.
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
        )


# Bind the concrete quantize-method implementations to the NVFP4 config.
# The KV cache reuses the FP8 KV-cache method.
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod