modelopt.py 53.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Any, Callable, Optional, Union
5
6
7
8
9

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

10
11
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
13
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
14
from vllm.distributed import get_ep_group
15
from vllm.logger import init_logger
16
17
18
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
    FlashInferCutlassMoEPrepareAndFinalize)
19
20
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
21
22
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
23
from vllm.model_executor.layers.quantization import QuantizationMethods
24
25
26
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
28
29
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    apply_fp4_marlin_linear, is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
30
from vllm.model_executor.layers.quantization.utils.quant_utils import (
31
    GroupShape, is_layer_skipped)
32
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
33
    Fp8LinearOp, requantize_with_max_scale)
34
35
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)
36
from vllm.platforms import current_platform
37
from vllm.scalar_type import scalar_types
38
39
40

logger = init_logger(__name__)

# ``quant_algo`` values accepted by the ModelOpt config parsers in this module.
QUANT_ALGOS = ["FP8", "NVFP4"]
# ``kv_cache_quant_algo`` values listed as supported by this module.
KV_CACHE_QUANT_ALGOS = ["FP8"]
43
44
45
46
47
48
49
50


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights (with their scales).
        kv_cache_quant_method: The ``kv_cache_quant_algo`` value read from the
            checkpoint config, if any.
        exclude_modules: Module names that must stay unquantized; matched by
            substring against layer prefixes in ``is_layer_excluded``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                           " the format is experimental and could change.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` string contains ``"FP8"``;
        otherwise None. Missing, null or non-string fields are treated as
        "no match" instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. An explicit null (or any
        # non-string value) would crash on ``.lower()``, so type-check first.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Compressed-tensors style config with a top-level quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Same type guard for both layouts: a null quant_algo must not raise.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from a parsed ``hf_quant_config.json``.

        Raises:
            ValueError: if ``quantization`` is present but not a dict, if its
                ``quant_algo`` is missing/empty, or if the algorithm is not one
                of ``QUANT_ALGOS``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError(
                    "Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            exclude_modules = config.get("exclude_modules")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration.")

        # NOTE(review): QUANT_ALGOS also lists "NVFP4", which passes the check
        # above and yields is_checkpoint_fp8_serialized=False here — confirm
        # that building an FP8 config for an NVFP4 checkpoint is intended.
        is_checkpoint_fp8_serialized = ("FP8" in quant_method)

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
                   exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        A layer is excluded when any entry of ``exclude_modules`` occurs as a
        substring of ``prefix``. Multimodal models whose prefixes carry a
        ``language_model.`` prefix are covered by the same test: a name found
        in the prefix without ``language_model.`` is necessarily also found in
        the full prefix.
        """
        if self.exclude_modules is None:
            return False
        return any(module in prefix for module in self.exclude_modules)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unhandled."""
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Loads FP8 checkpoints that provide a static weight scale and a static
    activation (input) scale. Dynamic scales may be supported in the future.

    Limitations:
    1. Only per-tensor quantization is supported (torch._scaled_mm).
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        # Activations are quantized with a static, per-tensor input scale.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, the scale params."""
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)
        serialized = self.quant_config.is_checkpoint_fp8_serialized

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        storage_dtype = torch.float8_e4m3fn if serialized else params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(out_features,
                                 input_size_per_partition,
                                 dtype=storage_dtype),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ))

        if not serialized:
            return

        # One scale per logical shard (e.g. q/k/v), filled with float32 min
        # until the checkpoint values are loaded.
        shard_count = len(output_partition_sizes)
        for param_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.empty(shard_count, dtype=torch.float32),
                weight_loader=loader,
            )
            scale_param[:] = torch.finfo(torch.float32).min
            layer.register_parameter(param_name, scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales into single per-tensor scales.

        When the shard weight scales differ, the weight is requantized to the
        largest one. The stored weight is transposed for the GEMM.
        """
        shard_scales = layer.weight_scale
        combined_scale = shard_scales.max()
        new_weight = layer.weight
        if not bool((shard_scales == shard_scales[0]).all()):
            combined_scale, new_weight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths)

        layer.weight = Parameter(new_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(combined_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(),
                                      requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM with the layer's static scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
256
257


258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    static per-tensor activation scales.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported)
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        ``w13_weight`` has shape (num_experts, 2 * intermediate, hidden) and
        ``w2_weight`` has shape (num_experts, hidden, intermediate).
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             2 * intermediate_size_per_partition,
                             hidden_size,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             hidden_size,
                             intermediate_size_per_partition,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if not self.quant_config.is_checkpoint_fp8_serialized:
            return

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt.
        # Allocate 2 scales for w1 and w3 respectively; they are combined
        # into a single scale per expert in process_weights_after_loading.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, 2), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE: these scale parameters carry only the caller-supplied
        # weight_loader; no extra attributes (such as ``quant_method``) are
        # attached to them here.

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        Per-expert (w1, w3) weight scales are merged into one scale per expert
        by requantizing both shards to the larger scale, and the per-expert
        input scales are reduced to a single scalar via max.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data,
                                     requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize)

        w13_scale = getattr(layer, "w13_weight_scale", None)
        if w13_scale is not None:
            if w13_scale.dim() == 2:
                # The fused kernel takes one w13 scale per expert: use the max
                # of the w1 and w3 scales, then dequant/requant each shard.
                max_w13_scales = w13_scale.max(dim=1).values

                # w13_weight is (num_experts, 2 * intermediate, hidden); the
                # first `intermediate` rows are w1, the next are w3.
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    for shard_id in range(2):  # w1 and w3
                        start = shard_id * intermediate_size
                        rows = slice(start, start + intermediate_size)
                        # Dequantize with the shard's original scale...
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][rows, :],
                            w13_scale[expert_id][shard_id],
                        )
                        # ...and requantize in place with the combined scale.
                        layer.w13_weight[expert_id][rows, :], _ = (
                            scaled_fp8_quant(dq_weight,
                                             max_w13_scales[expert_id]))

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales,
                                                   requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(w13_scale.data,
                                                   requires_grad=False)

        w2_scale = getattr(layer, "w2_weight_scale", None)
        if w2_scale is not None:
            layer.w2_weight_scale = Parameter(w2_scale.data,
                                              requires_grad=False)

        # apply() passes a single activation scale per GEMM, so collapse the
        # per-expert input scales to their maximum.
        w13_input_scale = getattr(layer, "w13_input_scale", None)
        if w13_input_scale is not None:
            layer.w13_input_scale = Parameter(w13_input_scale.max(),
                                              requires_grad=False)
        w2_input_scale = getattr(layer, "w2_input_scale", None)
        if w2_input_scale is not None:
            layer.w2_input_scale = Parameter(w2_input_scale.max(),
                                             requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route tokens and run the FP8 fused-experts kernel.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )


475
476
477
478
479
480
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4.

    Args:
        is_checkpoint_nvfp4_serialized: Whether the checkpoint stores NVFP4
            weights.
        kv_cache_quant_algo: The ``kv_cache_quant_algo`` value from the
            checkpoint config, if any.
        exclude_modules: Module-name patterns (``*`` wildcards allowed) that
            must stay unquantized.
        group_size: Quantization block size; 16 when not specified.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Always set these: get_quant_method() reads self.exclude_modules
        # regardless of whether the checkpoint is NVFP4-serialized, so
        # setting them only in the branch below raised AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Missing, null or non-string fields are treated as "no match" instead
        of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; guard against a null or
        # non-string value, which would crash on ``.lower()``.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        # NOTE(review): the nested layout requires a case-sensitive "NVFP4"
        # while the flat layout accepts any case of "FP4" — confirm intended.
        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_optional_fields(
            source: dict[str, Any]) -> tuple[Optional[str], int, list]:
        """Read and validate kv_cache_quant_algo, group_size and
        exclude_modules from one level of the quantization config.

        Raises:
            ValueError: if any of the fields has an unexpected type.
        """
        # No KV cache quantization by default
        kv_cache_quant_algo = source.get("kv_cache_quant_algo")
        if (kv_cache_quant_algo is not None
                and not isinstance(kv_cache_quant_algo, str)):
            raise ValueError(f"kv_cache_quant_algo must be a string, got "
                             f"{type(kv_cache_quant_algo)}")

        group_size_raw = source.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(f"group_size must be an integer, got "
                                 f"{type(group_size_raw)}") from None

        exclude_modules = source.get("exclude_modules", [])
        if not isinstance(exclude_modules, list):
            raise ValueError(f"exclude_modules must be a list, got "
                             f"{type(exclude_modules)}")

        return kv_cache_quant_algo, group_size, exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from a parsed ``hf_quant_config.json``.

        Both the traditional ModelOpt layout
        (``{"quantization": {"quant_algo": ...}}``) and the compressed-tensors
        style flat layout (``{"quant_algo": ..., "quant_method": ...}``) are
        accepted.

        Raises:
            ValueError: on a malformed ``quantization`` section, a missing or
                unsupported ``quant_algo``, badly typed optional fields, or
                (nested layout with NVFP4 only) missing required fields.
        """
        has_quantization_section = "quantization" in config
        if has_quantization_section:
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError(
                    "Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
        else:
            quant_config = config
            quant_method = config.get("quant_algo", "")

        kv_cache_quant_algo, group_size, exclude_modules = (
            cls._parse_optional_fields(quant_config))

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration.")
        is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

        # For FP4 in the nested layout, these fields are required
        if is_checkpoint_nvfp4_serialized and has_quantization_section:
            required_fields = [
                "group_size", "kv_cache_quant_algo", "exclude_modules"
            ]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}")

        return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                   exclude_modules, group_size)

    def is_layer_excluded(self, prefix: str, exclude_modules: list) -> bool:
        """Return True if ``prefix`` fully matches any pattern.

        Patterns are literal module names in which ``*`` matches any run of
        characters. All other characters, including regex metacharacters, are
        matched literally.
        """
        import regex as re
        for pattern in exclude_modules:
            # Escape everything, then re-enable only the '*' wildcard.
            regex_str = re.escape(pattern).replace(r'\*', '.*')
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unhandled."""
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if (is_layer_skipped(prefix, self.exclude_modules)
                    or self.is_layer_excluded(prefix, self.exclude_modules)):
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptNvFp4FusedMoE(self)
        return None


def cutlass_fp4_supported() -> bool:
    """Report whether the CUTLASS FP4 scaled-mm kernels are usable here.

    Always False on non-CUDA platforms. On CUDA, the device capability is
    converted with ``to_int()`` (or replaced by -1 when it is unknown) and
    passed to ``cutlass_scaled_mm_supports_fp4``.
    """
    if current_platform.is_cuda():
        device_cap = current_platform.get_device_capability()
        cap_value = -1 if device_cap is None else device_cap.to_int()
        return cutlass_scaled_mm_supports_fp4(cap_value)
    return False


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method shared by the ModelOpt FP8 and NVFP4 configs.

    Supports loading kv-cache scaling factors from FP8 checkpoints. This
    subclass adds no behavior of its own; it only forwards the config to
    ``BaseKVCacheMethod``.
    """

    def __init__(
        self,
        quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config],
    ) -> None:
        super().__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4 checkpoints.

    ``create_weights`` allocates the following per-layer tensors, which the
    checkpoint is expected to fill:

    - ``weight``: packed NVFP4 values stored as ``torch.uint8`` (two FP4
      values per byte), shape ``[output_size, input_size // 2]``.
    - ``weight_scale``: FP8-E4M3 per-block scales, shape
      ``[output_size, input_size // group_size]``.
    - ``weight_scale_2``: ``torch.float32`` global weight scale, one entry
      per logical output partition.
    - ``input_scale``: ``torch.float32`` global activation scale, one entry
      per logical output partition.

    If the CUTLASS NVFP4 GEMM is unavailable but the Marlin FP4 kernel is,
    the layer is executed through Marlin instead.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        self.use_marlin = False

        if not self.cutlass_nvfp4_supported:
            # Fall back to the Marlin FP4 kernel when CUTLASS cannot be used.
            if is_fp4_marlin_supported():
                self.use_marlin = True
            else:
                raise ValueError("Current platform does not support NVFP4"
                                 " quantization. Please use Blackwell and"
                                 " above.")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 weight and scale parameters on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")

        # The checkpoint is known to be NVFP4-serialized at this point (see
        # the check above), so block scales are always FP8-E4M3 and
        # params_dtype is not needed.
        block_scale_dtype = torch.float8_e4m3fn

        # Weight: 2 fp4 items are packed per byte along the input dimension.
        weight = ModelWeightParameter(
            data=torch.empty(
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # Input (activation) global scale, one per logical partition.
        input_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                              weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

        # Global weight scale, one per logical partition.
        weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per-block weight scale: one FP8 scale per group_size inputs.
        weight_scale = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // self.quant_config.group_size,
            dtype=block_scale_dtype,
        ),
                                            input_dim=1,
                                            output_dim=0,
                                            weight_loader=weight_loader)

        layer.register_parameter("weight_scale", weight_scale)

    def swizzle_blockscale(self, scale: torch.Tensor) -> torch.Tensor:
        """Pad and block-interleave FP8-E4M3 block scales.

        ``scale`` may be 2-D ``[M, K]`` or batched 3-D ``[B, M, K]``. ``M`` is
        padded up to a multiple of 128 and ``K`` up to a multiple of 4, then
        every 128x4 tile is reordered into the interleaved layout.

        Returns:
            The swizzled scales on the current CUDA device, shaped
            ``[M_padded, K_padded]`` for 2-D input or
            ``[B, M_padded, K_padded]`` for 3-D input.
        """
        assert scale.dtype == torch.float8_e4m3fn
        scale_ndim = scale.ndim
        if scale_ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape

        def round_up(x: int, m: int) -> int:
            return (x + m - 1) // m * m

        M_padded = round_up(M, 128)
        K_padded = round_up(K, 4)
        # Allocate the padded buffer on the input's device so the copy does
        # not round-trip through host memory.
        padded_scale = torch.zeros((B, M_padded, K_padded),
                                   dtype=scale.dtype,
                                   device=scale.device)
        padded_scale[:, :M, :K] = scale
        # [B, M/128, 4, 32, K/4, 4] -> [B, M/128, K/4, 32, 4, 4]
        padded_scale = padded_scale.reshape(B, M_padded // 128, 4, 32,
                                            K_padded // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        # Keep the padding. Reshaping to the unpadded (M, K) raises whenever
        # padding was added, because the element counts differ.
        # NOTE(review): the FP4 GEMM is presumably written against this
        # rounded block-scale shape. Confirm against the kernel's checks.
        if scale_ndim == 2:
            return swizzled_scale.reshape(M_padded, K_padded)
        return swizzled_scale.reshape(B, M_padded, K_padded)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize loaded NVFP4 tensors for the selected GEMM backend.

        This method does the following:

        - Collapses the per-partition ``input_scale`` and ``weight_scale_2`` to
          their maxima.
        - Precomputes ``alpha = input_scale * weight_scale_2``.
        - Stores a swizzled copy of the block scales as
          ``weight_scale_swizzled``.

        On the Marlin path the layer is converted by
        ``prepare_fp4_layer_for_marlin``, and the CUTLASS-only tensors are
        deleted.
        """
        # Global scales: one scalar per layer.
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Swizzle the weight block scale; the contracting dimension is the
        # input dimension.
        assert (layer.weight_scale.shape[1] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Block scale must be represented as FP8-E4M3")
        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        if self.use_marlin:
            prepare_fp4_layer_for_marlin(layer)
            # These tensors are only consumed by the CUTLASS path.
            del layer.alpha
            del layer.input_scale
            del layer.weight_scale_swizzled

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the NVFP4 linear layer to ``x`` (optionally adding ``bias``).

        On the CUTLASS path, ``x`` may have any number of leading dimensions.
        They are flattened for the GEMM and restored on the output, whose
        last dimension is the layer's output size.
        """
        if self.use_marlin:
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        output_dtype = x.dtype
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
        # The quantizer and the GEMM work on a 2-D [tokens, features] view;
        # for 2-D input this is the same tensor.
        x_2d = x.reshape(-1, x.shape[-1])

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        s_quant = 1 / layer.input_scale
        x_fp4, x_blockscale = scaled_fp4_quant(x_2d, s_quant)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert (x_fp4.dtype == torch.uint8)
        assert (layer.weight.dtype == torch.uint8)
        assert (x_blockscale.dtype == torch.float8_e4m3fn)
        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
        assert (layer.alpha.dtype == torch.float32)

        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
                                    layer.weight_scale_swizzled, layer.alpha,
                                    output_dtype)
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE method for NVFP4 (FP4) quantization.

    Args:
        quant_config: NVFP4 quant config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        """Pick the NVFP4 MoE backend for the current platform.

        FlashInfer CUTLASS is enabled only when requested through
        ``VLLM_USE_FLASHINFER_MOE_FP4`` and the device is CUDA with
        capability 100 and CUTLASS NVFP4 support. Without CUTLASS NVFP4
        support, the Marlin FP4 kernel is used if available; otherwise a
        ValueError is raised.
        """
        self.quant_config = quant_config
        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
        self.use_marlin = False
        self.allow_flashinfer_cutlass = False

        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
            flashinfer_usable = (self.cutlass_nvfp4_supported
                                 and current_platform.is_cuda()
                                 and current_platform.is_device_capability(100))
            if flashinfer_usable:
                logger.info_once(
                    "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
                self.allow_flashinfer_cutlass = True
            else:
                logger.warning_once(
                    "Flashinfer CUTLASS Fused MoE not supported "
                    "or found on the current platform.")

        if not self.cutlass_nvfp4_supported:
            if not is_fp4_marlin_supported():
                raise ValueError("Current platform does not support NVFP4"
                                 " quantization. Please use Blackwell and"
                                 " above.")
            self.use_marlin = True

        # Set later by maybe_swap_experts_impl when FlashInfer is enabled.
        self.fused_experts = None  # type: ignore

    def maybe_swap_experts_impl(
        self,
        moe_parallel_config: FusedMoEParallelConfig,
    ):
        """Install a FlashInfer CUTLASS modular kernel as ``fused_experts``.

        This is a no-op unless FlashInfer CUTLASS was enabled in
        ``__init__``. When enabled, it builds ``FlashInferExperts`` from the
        parallel config (TP/EP only). It pairs them with a FlashInfer
        prepare/finalize stage that quantizes activations to packed FP4
        (``torch.uint8``), and stores the resulting ``FusedMoEModularKernel``
        on ``self.fused_experts``.
        """
        if self.allow_flashinfer_cutlass:
            logger.debug_once("FlashInferExperts")
            cfg = moe_parallel_config
            experts_kwargs: dict[str, Any] = dict(
                use_nvfp4_w4a4=True,
                use_dp=cfg.dp_size > 1,
                ep_rank=cfg.ep_rank,
                ep_size=cfg.ep_size,
                tp_rank=cfg.tp_rank,
                tp_size=cfg.tp_size,
            )

            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                FlashInferExperts)
            prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
                # uint8 carries two packed e2m1 values (kernel requirement).
                quant_dtype=torch.uint8)
            self.fused_experts = mk.FusedMoEModularKernel(
                prepare_finalize,
                FlashInferExperts(**experts_kwargs),
            )

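    # select_gemm_impl is only invoked when a prepare/finalize stage exists
    # (it asserts prepare_finalize is not None). It returns the experts
    # implementation to pair with that stage and does not itself modify
    # self.fused_experts. When it is not invoked (TP case), apply() still has
    # two CUTLASS-based kernels: the installed FlashInfer modular kernel, or
    # cutlass_moe_fp4.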
    def select_gemm_impl(self, prepare_finalize,
                         moe) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the experts implementation to pair with ``prepare_finalize``.

        Only the FlashInfer CUTLASS backend can be combined with an external
        prepare/finalize stage here.

        Raises:
            ValueError: if FlashInfer CUTLASS is not enabled, because the
                native CutlassExpertsFp4 path cannot be used in this setup.
        """
        assert moe is not None
        assert prepare_finalize is not None
        all2all_manager = get_ep_group().device_communicator.all2all_manager
        assert all2all_manager is not None

        if not self.allow_flashinfer_cutlass:
            # No experts implementation exists for this path whatever the
            # dp_size is. Always raise the actionable error rather than
            # letting an assert on dp_size mask it.
            raise ValueError("CutlassExpertsFp4 doesn't support DP. "
                             "Use flashinfer CUTLASS FusedMoE backend instead "
                             "(set VLLM_USE_FLASHINFER_MOE_FP4=1)")

        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            FlashInferExperts)
        logger.debug_once("Using FlashInferExperts")
        return FlashInferExperts(
            use_nvfp4_w4a4=True,
            use_dp=moe.moe_parallel_config.dp_size > 1,
            ep_rank=moe.moe_parallel_config.ep_rank,
            ep_size=moe.moe_parallel_config.ep_size,
            tp_rank=moe.moe_parallel_config.tp_rank,
            tp_size=moe.moe_parallel_config.tp_size,
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales are named ``weight_scale_2``.

        NVFP4 expert weights keep their per-block scales in
        ``*_weight_scale`` and their per-tensor (global) scale in
        ``*_weight_scale_2``, so the answer is always True.
        """
        uses_suffix_2 = True
        return uses_suffix_2

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register NVFP4 expert weights and their scales on ``layer``.

        Per expert, GEMM 1 (fused w1/w3, ``w13``) has
        ``2 * intermediate_size_per_partition`` rows over ``hidden_size``
        inputs. GEMM 2 (``w2``) has ``hidden_size`` rows over
        ``intermediate_size_per_partition`` inputs. Tensor types and shapes:

        - Packed FP4 weights are ``torch.uint8``, with the last dimension
          halved.
        - Block scales are FP8-E4M3, with the last dimension divided by
          ``group_size``.
        - Global weight scales and input scales are ``torch.float32``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        group_size = self.quant_config.group_size
        weight_loader = extra_weight_attrs.get("weight_loader")

        def _expert_param(rows: int, cols: int,
                          dtype: torch.dtype) -> ModelWeightParameter:
            # [num_experts, rows, cols] parameter with the dim tags shared by
            # every expert weight / block-scale tensor below.
            return ModelWeightParameter(
                data=torch.empty(num_experts, rows, cols, dtype=dtype),
                input_dim=1,
                output_dim=2,
                weight_loader=weight_loader)

        def _fp32_param(*shape: int) -> PerTensorScaleParameter:
            # float32 global-scale parameter of the given shape.
            return PerTensorScaleParameter(
                data=torch.empty(*shape, dtype=torch.float32),
                weight_loader=weight_loader)

        # GEMM 1 weights: 2 fp4 items packed per byte along hidden_size.
        layer.register_parameter(
            "w13_weight",
            _expert_param(2 * intermediate_size_per_partition,
                          hidden_size // 2, torch.uint8))

        # GEMM 2 weights: 2 fp4 items packed per byte along the intermediate
        # dimension.
        layer.register_parameter(
            "w2_weight",
            _expert_param(hidden_size, intermediate_size_per_partition // 2,
                          torch.uint8))

        # FP8-E4M3 block scales: one scale per group_size input elements.
        layer.register_parameter(
            "w13_weight_scale",
            _expert_param(2 * intermediate_size_per_partition,
                          hidden_size // group_size, torch.float8_e4m3fn))
        layer.register_parameter(
            "w2_weight_scale",
            _expert_param(hidden_size,
                          intermediate_size_per_partition // group_size,
                          torch.float8_e4m3fn))

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        # Global weight scales; w13 keeps separate w1 / w3 entries.
        layer.register_parameter("w13_weight_scale_2",
                                 _fp32_param(num_experts, 2))
        layer.register_parameter("w2_weight_scale_2",
                                 _fp32_param(num_experts))

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # Activation (input) global scales; w13 again keeps w1 / w3 entries.
        layer.register_parameter("w13_input_scale",
                                 _fp32_param(num_experts, 2))
        layer.register_parameter("w2_input_scale", _fp32_param(num_experts))

    def swizzle_blockscale(self, scale: torch.Tensor) -> torch.Tensor:
        """Pad and block-interleave FP8-E4M3 block scales.

        ``scale`` may be 2-D ``[M, K]`` or batched 3-D ``[B, M, K]`` (one slice
        per expert). ``M`` is padded up to a multiple of 128 and ``K`` up to a
        multiple of 4. Every 128x4 tile is then reordered into the interleaved
        layout.

        Returns:
            The swizzled scales on the current CUDA device, shaped
            ``[M_padded, K_padded]`` for 2-D input or
            ``[B, M_padded, K_padded]`` for 3-D input.
        """
        assert scale.dtype == torch.float8_e4m3fn
        scale_ndim = scale.ndim
        if scale_ndim == 2:
            scale = scale.unsqueeze(0)
        assert scale.ndim == 3
        B, M, K = scale.shape

        def round_up(x: int, m: int) -> int:
            return (x + m - 1) // m * m

        M_padded = round_up(M, 128)
        K_padded = round_up(K, 4)
        # Allocate the padded buffer on the input's device so the copy does
        # not round-trip through host memory.
        padded_scale = torch.zeros((B, M_padded, K_padded),
                                   dtype=scale.dtype,
                                   device=scale.device)
        padded_scale[:, :M, :K] = scale
        # [B, M/128, 4, 32, K/4, 4] -> [B, M/128, K/4, 32, 4, 4]
        padded_scale = padded_scale.reshape(B, M_padded // 128, 4, 32,
                                            K_padded // 4, 4)
        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().cuda()
        # Keep the padding. Reshaping to the unpadded (M, K) raises whenever
        # padding was added (e.g. a TP-sharded 2*I that is not a multiple of
        # 128), because the element counts differ.
        # NOTE(review): the FP4 MoE kernels are presumably written against
        # this rounded per-expert block-scale shape. Confirm.
        if scale_ndim == 2:
            return swizzled_scale.reshape(M_padded, K_padded)
        return swizzled_scale.reshape(B, M_padded, K_padded)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded NVFP4 expert tensors into kernel-ready form.

        For the FlashInfer CUTLASS backend, GEMM 1 weights and block scales
        are reordered from ``[w1, w3]`` to ``[w3, w1]``. Global scales are
        reduced: w13 keeps the w1 entry of ``weight_scale_2`` and the max of
        its w1/w3 input scales. The method then does the following:

        - Precomputes the GEMM alphas and the inverted activation
          quantization scales.
        - Stores swizzled copies of the block scales.

        On the Marlin path the layer is converted by
        ``prepare_moe_fp4_layer_for_marlin``, and the tensors only used by the
        other backends are deleted.
        """
        # GEMM 1
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer_cutlass:
            # The FlashInfer Cutlass fused MoE kernel expects the combined
            # weights to be ordered as [w3, w1], unlike the standard [w1, w3]
            # layout.
            dim = -2
            size = gemm1_weight.size(dim)
            assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
            half = size // 2

            # Reorder weight
            w1, w3 = gemm1_weight.split(half, dim=dim)
            gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()

            # Reorder scale
            s1, s3 = gemm1_weight_scale.split(half, dim=dim)
            gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                           requires_grad=False)

        # A single global scale is used for the fused w13 GEMM, so differing
        # w1/w3 scales lose precision.
        if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                              layer.w13_weight_scale_2[:, 1]):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected.")

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                             requires_grad=False)

        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
            torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # The block-scale tensors are [E, rows, cols]; dim 2 is checked.
        assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(2) to be divisible by 16")
        assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w13_blockscale_swizzled = self.swizzle_blockscale(
            layer.w13_weight_scale)

        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                  requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False)

        # GEMM 2
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

        assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(2) to be divisible by 16")
        assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)

        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                 requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            # These tensors are only consumed by the non-Marlin backends.
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
            del layer.w13_blockscale_swizzled
            del layer.w2_blockscale_swizzled

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ):
        """Run the NVFP4 fused-MoE forward pass on ``x``.

        Experts are chosen with ``FusedMoE.select_experts``, and the tokens
        are then dispatched to exactly one backend:

        - Marlin, when ``self.use_marlin`` is set.
        - ``cutlass_moe_fp4``, when no modular kernel was installed
          (``self.fused_experts is None``).
        - The installed FlashInfer CUTLASS modular kernel otherwise.

        Raises:
            NotImplementedError: if ``enable_eplb`` is requested.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
        assert activation == "silu", "Only SiLU activation is supported."

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        if self.use_marlin:
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        if self.fused_experts is None:
            # No modular kernel installed: use cutlass_moe_fp4 for the TP
            # case only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
                cutlass_moe_fp4)
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                w1_blockscale=layer.w13_blockscale_swizzled,
                w2_blockscale=layer.w2_blockscale_swizzled,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
                device=x.device,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input)

        # FlashInfer CUTLASS modular kernel: TP or DP case.
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            is_valid_flashinfer_cutlass_fused_moe)
        assert is_valid_flashinfer_cutlass_fused_moe(
            x, layer.w13_weight, layer.w2_weight), (
                "Flashinfer CUTLASS Fused MoE not applicable!")

        # Single activation scale per GEMM. Named *_gscale to avoid confusion
        # with the batch-size-related a1_scale / a2_scale arguments.
        a1_gscale = torch.min(layer.w13_input_scale_quant)
        a2_gscale = torch.min(layer.w2_input_scale_quant)
        use_dp = layer.dp_size > 1
        local_tokens = x.shape[0]

        return self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,  # TODO(shuw): fix later, now output is high prec
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_blockscale_swizzled,
            w2_scale=layer.w2_blockscale_swizzled,
            apply_router_weight_on_input=apply_router_weight_on_input,
            extra_expert_args={
                'g1_alphas': layer.g1_alphas,
                'g2_alphas': layer.g2_alphas,
                'out_dtype': x.dtype,
                'a1_gscale': a1_gscale,
                'a2_gscale': a2_gscale,
            },
            extra_prepare_args={
                'use_dp': use_dp,
                'local_tokens': local_tokens,
                'a1_gscale': a1_gscale,
            },
            extra_finalize_args={
                'use_dp': use_dp,
                'local_tokens': local_tokens,
            },
        )