modelopt.py 67.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.fused_moe.config import (
16
17
18
19
20
    FusedMoEConfig,
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
22
from vllm.model_executor.layers.fused_moe.layer import (
23
24
25
26
27
28
29
30
31
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
32
from vllm.model_executor.layers.quantization import QuantizationMethods
33
from vllm.model_executor.layers.quantization.base_config import (
34
35
36
    QuantizationConfig,
    QuantizeMethodBase,
)
37
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
38
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
39
40
41
42
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
43
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
44
45
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
46
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
47
48
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
49
    is_flashinfer_supporting_global_sf,
50
51
52
53
54
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
55
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
56
57
58
59
60
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
61
from vllm.model_executor.layers.quantization.utils.quant_utils import (
62
63
64
65
66
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
67
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
68
69
70
71
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
72
from vllm.scalar_type import scalar_types
73
74
75
76
77
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
78

79
80
81
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

82
83
# Module-level logger shared by the ModelOpt config and quantize-method
# classes in this file (warnings on checkpoint detection, backend info logs).
logger = init_logger(__name__)

84
85
# Values accepted for the ModelOpt "quant_algo" field; ModelOptFp8Config
# .from_config raises ValueError for anything not in this list.
QUANT_ALGOS = ["FP8", "NVFP4"]
# KV-cache quantization algorithms listed as supported for ModelOpt.
KV_CACHE_QUANT_ALGOS = ["FP8"]
86
87
88
89
90
91
92
93


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Holds the settings parsed from a ModelOpt FP8 quantization config and
    maps each layer to its quantize method: FP8 linear, FP8 KV cache,
    FP8 fused MoE, or unquantized linear for excluded / vision layers.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: str | None = None,
        exclude_modules: list[str] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint stores weights
                already in FP8 together with their scales.
            kv_cache_quant_method: The config's "kv_cache_quant_algo" value,
                or None.
            exclude_modules: Module names / substrings to keep unquantized.
                ``None`` is normalized to an empty list.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` from HF names to vLLM module names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns "modelopt" when ``quant_method`` is "modelopt"
        (case-insensitive) and the quant algorithm string contains "FP8";
        otherwise returns None. Missing, null or non-string values are
        treated as "no match" rather than raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. The value may be an
        # explicit null, so check the type before calling ``.lower()``.
        quant_method = hf_quant_cfg.get("quant_method")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific structure: {"quantization": {"quant_algo": ...}}
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo")
        else:
            # Compressed-tensors style config with a top-level quant_algo.
            quant_algo = hf_quant_cfg.get("quant_algo")

        # Both layouts get the same guard: only a string can match "FP8".
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from either supported on-disk layout.

        Raises:
            ValueError: if "quantization" is not a dict, if "quant_algo" is
                missing in the ModelOpt layout, or if the algorithm is not in
                QUANT_ALGOS.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unsupported.

        Linear layers that are excluded or belong to a vision tower/model are
        left unquantized. Note that FusedMoE layers are not checked against
        ``exclude_modules`` here.
        """
        from vllm.attention.layer import (  # Avoid circular import
            Attention,
            MLAAttention,
        )

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, (Attention, MLAAttention)):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation (input) scale for each fused output shard. After loading,
    both are collapsed to one per-tensor value. Dynamic activation scales
    might be supported in the future.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm.
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static, per-tensor activation quantization using the input_scale
        # loaded from the checkpoint.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, its scales on ``layer``.

        ``weight`` has shape (sum(output_partition_sizes),
        input_size_per_partition). For FP8-serialized checkpoints,
        ``weight_scale`` and ``input_scale`` each hold one float32 slot per
        entry of ``output_partition_sizes``.
        """
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        num_shards = len(output_partition_sizes)
        out_features = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        serialized = self.quant_config.is_checkpoint_fp8_serialized
        storage_dtype = torch.float8_e4m3fn if serialized else params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    out_features, input_size_per_partition, dtype=storage_dtype
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not serialized:
            return

        # Weight scale first, then input scale (registration order preserved).
        # Each is pre-filled with the most negative finite float32 value
        # (presumably a placeholder until the loader writes real scales).
        for param_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.empty(num_shards, dtype=torch.float32),
                weight_loader=loader,
            )
            scale_param[:] = torch.finfo(torch.float32).min
            layer.register_parameter(param_name, scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales to per-tensor scales and transpose weight.

        If the shards' weight scales differ, the weight is requantized to
        the largest scale via ``requantize_with_max_scale``. The input scale
        becomes the max over shards.
        """
        shard_scales = layer.weight_scale
        uniform = bool((shard_scales == shard_scales[0]).all())
        if uniform:
            merged_scale, weight = shard_scales.max(), layer.weight
        else:
            merged_scale, weight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )

        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(merged_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on ``x`` with the layer's per-tensor scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            bias=bias,
            input_scale=layer.input_scale,
            weight_scale=layer.weight_scale,
        )
333
334


335
336
337
338
339
340
341
342
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static (per-tensor) weight scales
    and activation scales. Experts run either through FlashInfer (CUTLASS or
    TensorRT-LLM backend, when enabled) or through ``fused_experts``.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The MoE layer this method serves. Its ``moe_config`` is
            forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # FlashInfer is opt-in via VLLM_USE_FLASHINFER_MOE_FP8 and is only used
        # for gated (act-and-mul) MoE. While this stays None, `apply` uses
        # `fused_experts`.
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if (
            envs.VLLM_USE_FLASHINFER_MOE_FP8
            and has_flashinfer_moe()
            and self.moe.is_act_and_mul
        ):
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage matching the selected backend.

        Returns None for TensorRT-LLM and a FlashInfer CUTLASS prepare/finalize
        for CUTLASS. Otherwise it defers to the base class.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Select the CUTLASS FP8 experts implementation.

        ``prepare_finalize`` and ``layer`` are accepted for interface
        compatibility and are not used here. ``self.moe_quant_config`` must
        already be set.
        """
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        Registers on ``layer``:
          * ``w13_weight``: (num_experts, w13_up_dim, hidden_size).
            ``w13_up_dim`` is 2 * intermediate_size_per_partition for gated
            MoE (w1 rows followed by w3 rows), and
            intermediate_size_per_partition otherwise.
          * ``w2_weight``: (num_experts, hidden_size,
            intermediate_size_per_partition).
          * FP8-serialized checkpoints only, all float32 and initialised to
            1.0:
            - ``w13_weight_scale``: (num_experts, 2) for gated MoE,
              (num_experts, 1) otherwise.
            - ``w2_weight_scale``, ``w13_input_scale``, ``w2_input_scale``:
              (num_experts,).
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpt.
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They are combined into a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # The scale parameters receive their loader through
            # `weight_loader=` above. Nothing from `extra_weight_attrs` is
            # attached to them.

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        The steps are:
          * For gated MoE with separate w1/w3 scales, requantize each expert's
            w1 and w3 halves to their shared max scale, so w13 ends up with
            one scale per expert.
          * Reduce the input scales to a single max value each.
          * For FlashInfer backends, reorder w13 to w31, register scaling
            factors, and (TensorRT-LLM only) rotate the weights.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale; the slice
                        # assignment writes back into w13_weight in place.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 fused-MoE quant config from the layer's scales.

        Returns None for the TensorRT-LLM backend. That backend's `apply`
        path passes the layer directly to
        `apply_flashinfer_per_tensor_scale_fp8`.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        # NOTE(review): a1_gscale is the input scale itself, while a2_gscale
        # is the reciprocal of the w2 input scale. Presumably this matches
        # what the consuming kernels expect; verify before changing.
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=1.0 / layer.w2_input_scale,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and run the FP8 expert computation.

        Dispatch:
          * TensorRT-LLM backend: routing and experts are handled by
            `apply_flashinfer_per_tensor_scale_fp8`. Requires "silu" and
            renormalize=False.
          * CUTLASS backend: experts are selected here, then
            `flashinfer_cutlass_moe_fp8` runs them. Same restrictions apply.
          * Otherwise: experts are selected here and `fused_experts` runs
            in place with ``self.moe_quant_config``.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
678
679


680
681
682
683
684
685
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4 (NVFP4).

    ``from_config`` accepts two layouts:

    * the legacy ``hf_quant_config.json`` layout,
      ``{"quantization": {"quant_algo": ..., "group_size": ...,
      "kv_cache_quant_algo": ..., "exclude_modules": [...]}}``;
    * the flat (compressed-tensors style) layout found in ``config.json``,
      ``{"quant_method": "modelopt", "quant_algo": ..., "ignore": [...]}``.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """
        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint stores
                pre-quantized NVFP4 weights. The NVFP4 quant methods reject
                dynamic quantization, so this must be True to load weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name taken
                from the checkpoint config, or None.
            exclude_modules: Module names / wildcard patterns that must stay
                unquantized.
            group_size: Number of weight elements that share one block scale.
        """
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Always set these. is_layer_excluded(), apply_vllm_mapper() and
        # get_quant_method() read them regardless of serialization. Setting
        # them only for NVFP4-serialized checkpoints left a non-NVFP4 config
        # (e.g. quant_algo="FP8") without them, which raised AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate HF module names in ``exclude_modules`` to vLLM names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"`` and
        the quant algorithm names an FP4 scheme, otherwise None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. The key may be present
        # with a non-string value (e.g. null), which must not crash .lower().
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_common_fields(
        source: dict[str, Any], exclude_key: str
    ) -> tuple[str | None, int, list[str]]:
        """Validate ``kv_cache_quant_algo``, ``group_size`` and the exclude list.

        Both config layouts share this parsing. Only the key holding the
        exclude list differs (``exclude_modules`` vs ``ignore``).

        Returns:
            ``(kv_cache_quant_algo, group_size, exclude_modules)``; missing
            entries default to ``None``, ``16`` and ``[]``.

        Raises:
            ValueError: if a present field has the wrong type.
        """
        # Handle kv_cache_quant_algo with proper type validation
        kv_cache_quant_algo = source.get("kv_cache_quant_algo")
        if kv_cache_quant_algo is not None and not isinstance(
            kv_cache_quant_algo, str
        ):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_algo)}"
            )

        # Handle group_size with proper type validation
        group_size_raw = source.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        exclude_modules = source.get(exclude_key, [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        return kv_cache_quant_algo, group_size, exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from either supported layout.

        Raises:
            ValueError: on malformed fields, an unsupported ``quant_algo``, or
                (legacy layout with NVFP4) missing required fields.
        """
        # Handle both traditional ModelOpt format and compressed-tensors
        # style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            kv_cache_quant_algo, group_size, exclude_modules = (
                cls._parse_common_fields(quant_config, "exclude_modules")
            )
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")

            # "ignore" is the key in config.json
            kv_cache_quant_algo, group_size, exclude_modules = (
                cls._parse_common_fields(config, "ignore")
            )

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check regex pattern matching for patterns not caught by exact match
        import regex as re

        for pattern in self.exclude_modules:
            # Only dotted or wildcard patterns are tried as regexes: '.' is
            # matched literally and '*' matches any substring.
            if "*" in pattern or "." in pattern:
                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
                if re.fullmatch(regex_str, prefix):
                    return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the NVFP4 quant method for ``layer``.

        Returns None for excluded MoE layers and for unhandled layer types.
        """
        from vllm.attention.layer import (  # Avoid circular import
            Attention,
            MLAAttention,
        )

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, (Attention, MLAAttention)):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache scale loading for ModelOpt checkpoints.

    Used for attention layers of both ModelOpt FP8 and NVFP4 configs. This
    class adds no behaviour of its own; all loading is inherited from
    ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
        # Forward the config unchanged to the base KV-cache method.
        super().__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, one scalar per logical output partition
        (reduced to a single scalar in ``process_weights_after_loading``),
    weight: NVFP4 with two values packed per byte (torch.uint8),
        Shape: [output_size_per_partition, input_size_per_partition // 2],
    weight_scale: FP8-E4M3 per-block scale,
        Shape: [output_size_per_partition,
                input_size_per_partition // group_size],
    weight_scale_2: torch.float32, one scalar per logical output partition
        (reduced to a single scalar in ``process_weights_after_loading``).

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            # Auto-select. Only pick a native FP4 GEMM (FlashInfer or CUTLASS)
            # when cutlass_fp4_supported() reports FP4 support. Otherwise an
            # installed FlashInfer would shadow the Marlin fallback on
            # platforms without FP4 kernels.
            # NOTE(review): assumes FlashInfer's FP4 GEMM has the same
            # hardware requirements as vLLM's CUTLASS FP4 kernel.
            if cutlass_fp4_supported():
                if has_flashinfer():
                    self.backend = "flashinfer-cutlass"
                else:
                    self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 weight, block scale and global scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # NVFP4 weights are stored packed (two FP4 values per uint8 byte).
        # Their per-block scales are FP8-E4M3. Only serialized checkpoints
        # reach this point, so the scale dtype is always FP8.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale (one per logical output partition)
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale (one per logical output partition)
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale (one per `group_size` input elements)
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Fold global scales and convert weights to the backend's layout."""
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 GEMM for ``x`` on the selected backend."""
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1144
    Args:
1145
1146
1147
        quant_config: NVFP4 Quant Config
    """

1148
1149
1150
1151
1152
1153
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Record the config and detect which NVFP4 MoE kernels are usable.

        Args:
            quant_config: The ModelOpt NVFP4 config.
            moe: MoE configuration, forwarded to the base method.
            layer: The layer this method is created for.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin

        # Permute-index cache (keyed by torch.Size), filled lazily when
        # weights are prepared for the TRT-LLM FP4 MoE kernel.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

        self.flashinfer_moe_backend = None
        if not self.allow_flashinfer:
            return

        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )

1174
    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage for the selected MoE kernel.

        None for Marlin and for the FlashInfer TensorRT-LLM backend; the
        FlashInfer FP4 CUTLASS prepare/finalize for the FlashInfer CUTLASS
        backend; the base-class default otherwise.
        """
        active_backend = self.flashinfer_moe_backend if self.allow_flashinfer else None

        if self.use_marlin or active_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if active_backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize()

        # For now, fp4 moe only works with the flashinfer dispatcher.
        stage = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
        logger.debug_once("%s", stage.__class__.__name__)
        return stage

1193
1194
1195
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the NVFP4 experts implementation.

        ``prepare_finalize`` and ``layer`` belong to the base-class interface
        and are not used here. Requires ``self.moe_quant_config`` to be set.
        """
        quant_cfg = self.moe_quant_config
        assert quant_cfg is not None

        chosen = select_nvfp4_gemm_impl(
            self.moe, quant_cfg, allow_flashinfer=self.allow_flashinfer
        )
        logger.debug_once("Using %s", chosen.__class__.__name__)
        return chosen

1207
1208
1209
1210
1211
1212
    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales are named ``*_weight_scale_2``.

        NVFP4 keeps per-block scales in ``*_weight_scale`` and the per-tensor
        scale in ``*_weight_scale_2`` (see ``create_weights``).
        """
        return True

1213
1214
1215
1216
1217
1218
1219
1220
1221
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 expert weights and scales on ``layer``.

        With E = num_experts, I = intermediate_size_per_partition,
        H = hidden_size and G = quant_config.group_size:

        - ``w13_weight``: uint8 [E, 2*I, H // 2], two FP4 values per byte.
        - ``w2_weight``: uint8 [E, H, I // 2], two FP4 values per byte.
        - ``w13_weight_scale``: FP8-E4M3 [E, 2*I, H // G], per-block scales.
        - ``w2_weight_scale``: FP8-E4M3 [E, H, I // G], per-block scales.
        - ``w13_weight_scale_2``: float32 [E, 2], per-tensor scales
          (presumably one each for the fused w1/w3 halves).
        - ``w2_weight_scale_2``: float32 [E], per-tensor scales.
        - ``w13_input_scale`` [N, 2] / ``w2_input_scale`` [N]: float32, where
          N is ``global_num_experts`` when FlashInfer is allowed and its MoE
          backend supports global scale factors, else ``num_experts``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is not supported).
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")

        # NOTE(review): input_dim=1 / output_dim=2 below do not match the
        # [E, out, in] tensor layout. Left unchanged; verify whether the MoE
        # weight loader reads these attributes at all.

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # one FP8 block scale per `group_size` input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one FP8 block scale per `group_size` input elements
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Removed: the former `extra_weight_attrs.update({"quant_method": ...})`
        # calls here were dead writes. This local kwargs dict is never read
        # again after the lookups above.

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # Input scales are allocated for all global experts when the FlashInfer
        # backend consumes global scale factors; otherwise per local expert.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        global_scale_num_experts = global_num_experts if use_global_sf else num_experts

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

1330
    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        """Prepare quantized weights for kernel (done offline with weights).

        Reinterprets the packed FP4 expert weights and their FP8-E4M3 block
        scales. Then, per expert, it permutes rows with cached FlashInfer
        permute indices and interleaves the block scales with
        ``nvfp4_block_scale_interleave``. GEMM1 uses the w3/w1 indices and
        GEMM2 uses the w2 indices.

        Args:
            gemm1_weights: Packed FP4 w13 weights, reshaped to
                [num_experts, 2 * intermediate_size, hidden_size // 2].
            gemm2_weights: Packed FP4 w2 weights, reshaped to
                [num_experts, hidden_size, intermediate_size // 2].
            gemm1_scales_linear_fp4_bytes: w13 block scales (one per 16
                elements), reshaped to
                [num_experts, 2 * intermediate_size, hidden_size // 16].
            gemm2_scales_linear_fp4_bytes: w2 block scales, reshaped to
                [num_experts, hidden_size, intermediate_size // 16].
            hidden_size: Hidden size used for the reshapes above.
            intermediate_size: Intermediate size used for the reshapes above.
            num_experts: Number of experts in the given tensors.

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)``,
            stacked over experts. Weights are uint8; scales are
            float8_e4m3fn with the same shapes as the reshaped inputs.
        """
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w3_w1_permute_indices,
            get_w2_permute_indices_with_cache,
        )

        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        sf_vec_size = 16  # elements covered by one FP8 block scale

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // sf_vec_size
        )  # fp8 scaling factors

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_fp4_shuffled.append(
                gemm1_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm1_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

            permute_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_fp4_shuffled.append(
                gemm2_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm2_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_vec_size)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // sf_vec_size)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )

1454
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded NVFP4 MoE parameters into the selected backend's layout.

        All changes are applied to ``layer`` in place:

        1. When FlashInfer is enabled, reorder the fused ``w13`` weight and
           block scale from w1/w3 to w3/w1 order along dim -2.
        2. Collapse ``w13_weight_scale_2`` to its first column, warning once
           if the two columns (w1 vs. w3 scale) differ.
        3. Derive ``g1_alphas`` / ``g2_alphas`` (input scale times
           ``weight_scale_2``). Also derive the reciprocal input scales
           ``w13_input_scale_quant`` / ``w2_input_scale_quant`` used to
           quantize activations. When the FlashInfer backend supports a
           global scale factor, one max input scale is broadcast to
           ``layer.num_experts`` entries.
        4. Backend-specific finalization:

           * FlashInfer TRT-LLM: shuffle weights and block scales into
             ``gemm{1,2}_{weights,scales}_fp4_shuffled`` parameters and add
             ``g1_scale_c``. Then delete the original ``w13``/``w2`` weights
             and block scales.
           * Marlin: hand the layer to ``prepare_moe_fp4_layer_for_marlin``.
             Then delete the alpha and activation-quant parameters.
           * Otherwise (CUTLASS or non-FlashInfer): swizzle the ``w13`` and
             ``w2`` block scales with ``swizzle_blockscale``.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            # FlashInfer consumes the fused projection in w3/w1 order.
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2. Only the first column is
        # kept below, so the w1 and w3 scales are expected to agree.
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For FlashInfer backends that support a global scale factor, the
            # input global scale is shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            # Per-expert max over dim 1 (presumably the w1/w3 pair - confirm).
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For FlashInfer backends that support a global scale factor, the
            # input global scale is shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing; the alpha / activation-quant parameters
            # computed above are not used on this path.
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 quant config used by the modular fused-MoE kernels.

        Returns ``None`` for the Marlin and FlashInfer TRT-LLM paths. ``apply``
        calls those kernels directly without a quant config. On those paths,
        ``process_weights_after_loading`` also deletes some of the tensors
        this config would reference.

        Args:
            layer: The MoE layer whose processed NVFP4 scales are referenced.

        Returns:
            The NVFP4 ``FusedMoEQuantConfig`` for the CUTLASS paths, else ``None``.
        """
        # The TRT-LLM check mirrors the guard in process_weights_after_loading
        # and apply. The TRT-LLM layout is only prepared when FlashInfer is
        # allowed. Otherwise the CUTLASS path runs and needs this config.
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass on the configured backend.

        Dispatch order:

        * FlashInfer TRT-LLM: quantize ``x`` with ``flashinfer.fp4_quantize``.
          Then call ``trtllm_fp4_block_scale_moe``, which takes the raw
          router logits and does the routing inside the kernel call.
        * Marlin: ``fused_marlin_moe`` with FP4 (``float4_e2m1f``) weights.
        * FlashInfer CUTLASS: ``flashinfer_cutlass_moe_fp4``.
        * Otherwise: ``cutlass_moe_fp4``.

        All non-TRT-LLM paths first pick experts via ``FusedMoE.select_experts``.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
            AssertionError: If ``activation`` is not ``"silu"``.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            # NOTE(review): every non-Llama4 model is routed with DeepSeekV3
            # routing here, independent of use_grouped_topk / scoring_func /
            # renormalize. Confirm this is intended for other model families.
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4

            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)

            # Llama4 routing receives the logits in their incoming dtype;
            # every other routing method receives float32 logits.
            routing_logits = (
                router_logits
                if use_llama4_routing
                else router_logits.to(torch.float32)
            )

            # NOTE(review): routed_scaling_factor is passed as None even though
            # this method receives one. Presumably the scaling is applied
            # elsewhere for this path; confirm.
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=routing_logits,
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                # Two optional tensors left unset (presumably the gemm biases;
                # confirm against fused_marlin_moe's signature).
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4,
            )

            assert self.moe_quant_config is not None

            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )