modelopt.py 58.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from fnmatch import fnmatch
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
13
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
18
19
20
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
22
from vllm.model_executor.layers.fused_moe.layer import (
23
24
25
26
27
28
29
30
31
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
32
from vllm.model_executor.layers.quantization import QuantizationMethods
33
from vllm.model_executor.layers.quantization.base_config import (
34
35
36
    QuantizationConfig,
    QuantizeMethodBase,
)
37
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
38
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
39
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
40
41
    flashinfer_trtllm_fp4_moe,
    prepare_static_weights_for_trtllm_fp4_moe,
42
43
44
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
45
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
46
47
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
48
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
49
50
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
51
    is_flashinfer_supporting_global_sf,
52
53
54
55
56
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
57
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
58
59
60
61
62
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
63
from vllm.model_executor.layers.quantization.utils.quant_utils import (
64
65
66
67
68
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
69
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
70
71
72
73
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
74
from vllm.scalar_type import scalar_types
75
76
77
78
79
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
80

81
82
83
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

84
85
logger = init_logger(__name__)

# Accepted values of ``quant_algo``; ModelOptQuantConfigBase.from_config rejects
# any checkpoint whose algorithm is not listed here.
QUANT_ALGOS = ["FP8", "NVFP4"]
# Values of ``kv_cache_quant_algo`` listed for ModelOpt checkpoints.
# NOTE(review): not referenced in this part of the file; confirm where enforced.
KV_CACHE_QUANT_ALGOS = ["FP8"]


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method used by ModelOpt FP8 configs.

    Loads kv-cache scaling factors from FP8 checkpoints; all behaviour is
    inherited from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Nothing ModelOpt-specific to set up: hand the config to the base class.
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    """Shared config logic for ModelOpt quantization formats.

    Subclasses plug in their concrete method classes through the
    ``LinearMethodCls``, ``FusedMoEMethodCls`` and ``KVCacheMethodCls`` class
    attributes and implement :meth:`_from_config`.
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # Module names / ModelOpt wildcard patterns that must stay unquantized.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Three checks are applied in order:
        1. ``is_layer_skipped`` exact matching (fused layers are resolved via
           ``packed_modules_mapping``).
        2. Legacy substring matching, also trying ``prefix`` with a leading
           ``language_model.`` removed.
        3. ``fnmatch`` wildcard matching, because ModelOpt ``exclude_modules``
           entries are wildcards.

        Args:
            prefix: Fully qualified module name of the layer.

        Returns:
            True if any check matches, False otherwise (always False when the
            exclusion list is empty).
        """
        if len(self.exclude_modules) == 0:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module != prefix and (
                exclude_module in prefix
                or (
                    prefix.startswith("language_model.")
                    and exclude_module in prefix.removeprefix("language_model.")
                )
            ):
                return True

        # modelopt exclude modules are not simple strings, they are wildcards
        for wildcard_pattern in self.exclude_modules:
            if fnmatch(prefix, wildcard_pattern):
                return True

        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Returns the KV-cache method for attention layers, and an unquantized
        linear method (or ``None`` for non-linear layers) for excluded or
        legacy vision modules. Otherwise returns the subclass's linear or MoE
        method, or ``None`` for any other layer type.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # handle exclusion
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        if "vision_tower" in prefix or "vision_model" in prefix:
            # Same treatment as an explicit exclusion: UnquantizedLinearMethod is
            # only valid for linear layers, so other layer types get None.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            return self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            return self.FusedMoEMethodCls(quant_config=self, layer=layer)

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename exclusion entries from HF checkpoint names to vLLM names."""
        if len(self.exclude_modules) > 0:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Parse and validate a ModelOpt config, then delegate to ``_from_config``.

        Two layouts are accepted:
        * legacy ``hf_quant_config.json``: ``{"quantization": {"quant_algo": ...}}``
          with exclusions under ``exclude_modules``;
        * compressed-tensors style ``config.json``:
          ``{"quant_algo": ..., "quant_method": "modelopt"}`` with exclusions
          under ``ignore``.

        Raises:
            ValueError: if ``quant_algo`` is missing or not in ``QUANT_ALGOS``,
                or if ``kv_cache_quant_algo``, the exclusion list or
                ``group_size`` has the wrong type.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # None means "no KV cache quantization"; anything else must be a string.
        if kv_cache_quant_method is not None and not isinstance(
            kv_cache_quant_method, str
        ):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        super().__init__(exclude_modules)
        # When True, weights are stored as float8_e4m3fn with per-tensor scales.
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` only when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the algorithm string contains ``"FP8"``.
        Missing or non-string fields mean "not this config" and yield ``None``
        instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; guard against JSON null or
        # other non-string values, which previously raised AttributeError.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            quant_algo = (
                quant_config.get("quant_algo", "")
                if isinstance(quant_config, dict)
                else None
            )
        else:
            # Compressed-tensors style config with a top-level quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        # Any algorithm name containing "FP8" means the weights are stored in FP8.
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static (pre-computed) activation scale, one value per tensor.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` (and, for FP8 checkpoints, its scales) on ``layer``.

        One weight scale and one input scale are allocated per entry of
        ``output_partition_sizes``; these are collapsed to single scalars in
        :meth:`process_weights_after_loading`.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            # Pre-fill with the most negative float32, presumably so a shard that
            # is never loaded cannot win the later .max() reduction.
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales into one per-tensor scale and transpose weight.

        If the loaded shard weight scales differ, ``requantize_with_max_scale``
        produces a weight consistent with the max scale. The input scale is
        reduced to its max.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        # NOTE(review): transposed here, presumably the layout Fp8LinearOp expects.
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op with the layer's static per-tensor scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The FusedMoE layer this method is attached to.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # None means the generic fused_experts path is used in apply().
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            if (
                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and not self.moe.is_act_and_mul
            ):
                logger.info_once(
                    "Non-gated MoE is not supported for min-latency mode, "
                    "falling back to high-throughput mode"
                )
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend, if any."""
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Select the CUTLASS FP8 experts implementation for modular kernels."""
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, per-tensor scales.

        ``w13_weight`` has shape (num_experts, w13_up_dim, hidden_size), where
        ``w13_up_dim`` is doubled for gated (act-and-mul) MoE. ``w2_weight``
        has shape (num_experts, hidden_size, intermediate_size_per_partition).
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Steps:
        * Merge the separate w1/w3 weight scales into one per-expert max scale,
          requantizing both halves of ``w13_weight`` against it.
        * Reduce each input scale to a single max value.
        * For FlashInfer backends, reorder w13 to w3/w1 (gated only) and, for
          TensorRT-LLM, rotate the weights.
        * Register the derived MoE scaling factors on ``layer``.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale; the slice
                        # target writes the result back into w13_weight in place.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            if self.moe.is_act_and_mul:
                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        # Runs for every backend: get_fused_moe_quant_config reads the
        # attributes registered here.
        register_moe_scaling_factors(layer)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config; ``None`` for the TensorRT-LLM backend."""
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=layer.output2_scales_scalar.squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass on the configured backend.

        TensorRT-LLM does routing itself. Other paths call
        ``layer.select_experts`` first, then dispatch to FlashInfer CUTLASS or
        the generic ``fused_experts`` kernel.

        Raises:
            NotImplementedError: EPLB is requested with the TensorRT-LLM backend.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptFp8MoEMethod` yet."
                )
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )


# Bind the FP8 config to its concrete linear / MoE / KV-cache method classes,
# which ModelOptQuantConfigBase.get_quant_method instantiates.
(
    ModelOptFp8Config.LinearMethodCls,
    ModelOptFp8Config.FusedMoEMethodCls,
    ModelOptFp8Config.KVCacheMethodCls,
) = (ModelOptFp8LinearMethod, ModelOptFp8MoEMethod, ModelOptFp8KVCacheMethod)


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4.

    Args:
        is_checkpoint_nvfp4_serialized: Whether the checkpoint already stores
            NVFP4-quantized weights. The linear and MoE methods in this file
            reject non-serialized checkpoints (no dynamic quantization).
        kv_cache_quant_algo: KV-cache quantization algorithm taken from the
            checkpoint config, or ``None``.
        exclude_modules: Module patterns that must stay unquantized; forwarded
            to the base config.
        group_size: Number of weight elements sharing one FP8 block scale.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set these unconditionally. They used to be assigned only for
        # serialized checkpoints, so any access on a non-serialized config
        # raised AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``hf_quant_cfg["quant_method"]`` is
        ``"modelopt"`` (case-insensitive) and the ``quant_algo`` field names an
        FP4 algorithm. Otherwise returns ``None``. Values that are missing,
        ``null`` or not strings are treated as "no match" instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. It may be explicitly null
        # in the JSON, and ``.lower()`` would then crash.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                # Guard like the branch below: ``in`` on None raises TypeError.
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from fields parsed out of the checkpoint config.

        Raises:
            ValueError: if the checkpoint is NVFP4-serialized and its
                ``quantization`` section lacks ``group_size``,
                ``kv_cache_quant_algo`` or ``exclude_modules``.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters, shown
    as allocated by ``create_weights`` (``N`` = output size per partition,
    ``K`` = input size per partition):

    - input_scale: torch.float32, one entry per fused output partition;
      reduced to a scalar (max) in ``process_weights_after_loading``.
    - weight: NVFP4 with two values packed per byte (torch.uint8), shape
      ``[N, K // 2]``.
    - weight_scale: FP8-E4M3 per-block scale, shape ``[N, K // group_size]``.
    - weight_scale_2: torch.float32 global weight scale, one entry per fused
      output partition; reduced to a scalar (max) after loading.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        # Backend selection. With no explicit VLLM_NVFP4_GEMM_BACKEND, prefer
        # FlashInfer CUTLASS, then vLLM CUTLASS, then Marlin. An explicit value
        # must start with "flashinfer-" or equal "cutlass". Any other value
        # leaves the backend at "none" and raises below.
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weight, per-block scale and global scales.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or ``input_size_per_partition``
                is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )
        # Per-block weight scales are FP8-E4M3 for serialized NVFP4
        # checkpoints (the non-serialized case has already raised above).
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input activation global scale, one entry per fused output partition
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale, one entry per fused output partition
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale: one scale per group_size input elements
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize loaded params for the selected GEMM backend.

        Reduces the per-partition global scales to scalars (max), precomputes
        ``alpha = input_scale * weight_scale_2`` and ``1 / input_scale``, then
        re-lays out the weight / block scale as the backend requires.
        """
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            # Marlin's apply path reads weight_scale_2 directly, not these.
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W.T (+ bias)`` with an NVFP4 weight.

        ``x`` is quantized to FP4 with an FP8 block scale, then multiplied
        with the chosen backend. Leading dimensions of ``x`` are preserved in
        the output; the last dimension becomes the output feature size.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        # Keep every leading dim of x (previously only x.shape[0], which broke
        # inputs with more than two dims); identical for 2-D inputs.
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.

    Depending on platform support this runs FlashInfer (TensorRT-LLM, CUTLASS
    or CuteDSL), Marlin, or vLLM's CUTLASS NVFP4 MoE kernels.

    Args:
        quant_config: NVFP4 Quant Config
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        layer: FusedMoE,
    ) -> None:
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.layer = layer
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        # Stays None unless FlashInfer is allowed.
        self.flashinfer_moe_backend = None
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for ModelOptNvFp4FusedMoE."
            )
        elif self.use_marlin:
            logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
        else:
            logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage for the modular kernel path.

        Marlin and FlashInfer TRT-LLM return ``None``. FlashInfer CUTLASS uses
        its own dispatcher. Everything else defers to the base class.
        """
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the NVFP4 experts implementation for the modular kernel."""
        assert self.moe_quant_config is not None
        experts = select_nvfp4_gemm_impl(
            self.moe,
            self.moe_quant_config,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed expert weights, block scales and global scales.

        Shapes (``E`` = local experts, ``I`` = intermediate size per partition,
        ``H`` = hidden size, ``G`` = group size):
        w13_weight ``[E, 2I, H//2]`` uint8, w2_weight ``[E, H, I//2]`` uint8,
        w13_weight_scale ``[E, 2I, H//G]`` and w2_weight_scale ``[E, H, I//G]``
        in FP8-E4M3, w13_weight_scale_2 ``[E, 2]`` and w2_weight_scale_2
        ``[E]`` in float32. Input scales are sized by the global expert count
        when the FlashInfer backend shares them across experts.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # NOTE(review): may be None if the caller does not pass it; only read
        # below when global scale factors are in use — confirm callers set it.
        global_num_experts = extra_weight_attrs.get("global_num_experts")

        # NOTE(review): input_dim/output_dim below look swapped relative to
        # the [E, out, in] allocation; presumably ignored by the FusedMoE
        # weight loader — confirm before relying on them.
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # one FP8 block scale per group_size input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one FP8 block scale per group_size input elements
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        # Separate w1 / w3 global scales; merged into one in
        # process_weights_after_loading.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        global_scale_num_experts = global_num_experts if use_global_sf else num_experts

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded params into the layout the chosen kernel expects.

        Computes per-expert alphas (input scale * weight global scale) and
        inverted input scales, then applies backend-specific layout changes:
        - TRT-LLM: shuffled weights/scales.
        - Marlin: repacking.
        - Others: swizzled block scales.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer and (
            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2: w1 and w3 share one global
        # scale at runtime, so the w1 value is kept.
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the NVFP4 quant config, or ``None`` for Marlin / TRT-LLM.

        Those two paths call their kernels directly in ``apply`` and do not
        use this config.
        """
        if (
            self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass.

        FlashInfer TRT-LLM takes the raw router logits. All other paths first
        call ``layer.select_experts`` and then dispatch to Marlin, FlashInfer
        (CUTLASS / CuteDSL) or vLLM's ``cutlass_moe_fp4``. Only SiLU is
        supported, and EPLB is rejected on the TRT-LLM path.
        """
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            if enable_eplb:
                raise NotImplementedError(
                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
                )
            return flashinfer_trtllm_fp4_moe(
                layer=layer,
                x=x,
                router_logits=router_logits,
                top_k=top_k,
                global_num_experts=global_num_experts,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                custom_routing_function=custom_routing_function,
                e_score_correction_bias=e_score_correction_bias,
            )

        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        elif self.allow_flashinfer:
            assert self.flashinfer_moe_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

            assert self.moe_quant_config is not None
            return flashinfer_fn_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                # w2 packs two fp4 values per byte along the intermediate dim
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )


# Bind the NVFP4 method classes onto ModelOptNvFp4Config. The KV cache reuses
# the FP8 KV-cache method; presumably the shared config base reads these when
# choosing a quantize method — confirm against ModelOptQuantConfigBase.
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod