# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    is_flashinfer_supporting_global_sf,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

# Accepted values of the checkpoint's ``quant_algo`` field; ModelOpt configs
# reject anything else when they are built from a checkpoint config.
QUANT_ALGOS = ["FP8", "NVFP4"]
# Accepted values of ``kv_cache_quant_algo`` (presumably validated by code
# elsewhere in this module -- TODO confirm).
KV_CACHE_QUANT_ALGOS = ["FP8"]


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Accepts two checkpoint layouts:

    * the legacy ModelOpt ``hf_quant_config.json`` layout, where settings live
      under a nested ``"quantization"`` dict, and
    * a flat, compressed-tensors style layout (``quant_algo`` /
      ``quant_method`` at the top level, excluded modules under ``"ignore"``).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: str | None = None,
        exclude_modules: list[str] | None = None,
    ) -> None:
        """Store the parsed FP8 settings.

        Args:
            is_checkpoint_fp8_serialized: Whether checkpoint weights are
                already stored as FP8 (with accompanying scales).
            kv_cache_quant_method: Raw ``kv_cache_quant_algo`` value, or None.
            exclude_modules: Module names to leave unquantized; None is
                normalized to an empty list.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU capability required by this method."""
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the checkpoint file(s) that hold the quantization config."""
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate excluded module names from HF naming to vLLM naming."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` string mentions ``FP8``;
        otherwise None. JSON ``null`` values for either field are treated as
        "not a match" instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. ``or ""`` also covers an
        # explicit null, which would otherwise crash on ``.lower()``.
        quant_method = hf_quant_cfg.get("quant_method") or ""
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Legacy ModelOpt layout: {"quantization": {"quant_algo": "..."}}
            quant_config = hf_quant_cfg["quantization"]
            quant_algo = (
                quant_config.get("quant_algo", "")
                if isinstance(quant_config, dict)
                else ""
            )
        else:
            # Compressed-tensors style config with a top-level quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Only a string can name an algorithm; null/other types never match.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build a config from either supported checkpoint layout.

        Raises:
            ValueError: If ``"quantization"`` is not a dict, if ``quant_algo``
                is missing from it, or if the algorithm is not in
                ``QUANT_ALGOS``.
        """
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        # NOTE(review): QUANT_ALGOS also contains "NVFP4", which reaches this
        # point and yields is_checkpoint_fp8_serialized=False for an FP8
        # config -- confirm that is intended.
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """Return True if the layer at ``prefix`` must stay unquantized.

        Two checks are applied in order:

        1. ``is_layer_skipped`` performs exact matching, including fused
           layers through ``packed_modules_mapping``.
        2. Any exclude entry that differs from ``prefix`` but occurs as a
           substring of it also excludes the layer. This covers multimodal
           models whose prefixes start with ``language_model.``: the name
           with that prefix removed is itself a substring of ``prefix``, so
           it needs no separate check.
        """
        if self.exclude_modules is None:
            return False

        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Exact matches were already handled above, hence ``module != prefix``.
        return any(
            module != prefix and module in prefix for module in self.exclude_modules
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns None for layer types this config does not handle.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static activation quantization with a single per-tensor scale.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, its scales) on layer.

        The weight has shape (sum(output_partition_sizes),
        input_size_per_partition). For FP8-serialized checkpoints, one weight
        scale and one input scale are allocated per entry of
        ``output_partition_sizes``.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            # Initialized to the most negative float32; presumably a sentinel
            # so the max() taken after loading ignores unloaded shards.
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales into single per-tensor scales.

        If the shards' weight scales differ, the weight is requantized to the
        maximum scale. The weight is stored transposed afterwards.

        NOTE(review): this reads ``layer.weight_scale`` / ``layer.input_scale``,
        which create_weights only registers for FP8-serialized checkpoints.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM with the layer's static per-tensor scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        # FlashInfer kernels are used only when requested via env var, when
        # FlashInfer MoE is available, and only for gated (act-and-mul) MoE.
        if (
            envs.VLLM_USE_FLASHINFER_MOE_FP8
            and has_flashinfer_moe()
            and self.moe.is_act_and_mul
        ):
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return a FlashInfer CUTLASS prepare/finalize object when applicable.

        Returns None for the TensorRT-LLM backend; defers to the base class
        when no FlashInfer backend is selected.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Select the CUTLASS FP8 experts implementation.

        ``prepare_finalize`` and ``layer`` are part of the base-class
        interface and are not used here.
        """
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and (for FP8 checkpoints) their scales.

        Shapes:
            w13_weight: (num_experts, w13_up_dim, hidden_size), where
                w13_up_dim is 2 * intermediate_size_per_partition for gated
                MoE (w1 and w3 stacked) and intermediate_size_per_partition
                otherwise.
            w2_weight: (num_experts, hidden_size,
                intermediate_size_per_partition).
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        For gated MoE, the separate w1/w3 weight scales are merged into one
        per-expert scale (their max) and both halves are requantized to it.
        Per-expert input scales are reduced to a single max value. When a
        FlashInfer backend is active, weights are additionally reordered for
        that backend.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # w13_weight is (num_experts, 2 * intermediate_size, hidden_size):
                # the first intermediate_size rows are w1, the next are w3.
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    for shard_id in range(2):  # w1 then w3
                        rows = slice(
                            shard_id * intermediate_size,
                            (shard_id + 1) * intermediate_size,
                        )
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][rows, :],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize in place using the combined max scale
                        layer.w13_weight[expert_id][rows, :], _ = scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id]
                        )

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 MoE quant config from the layer's scales.

        Returns None for the TensorRT-LLM backend.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        # NOTE(review): a1_gscale is the input scale itself while a2_gscale is
        # its reciprocal. The asymmetry is preserved as-is; confirm it matches
        # what the downstream (FlashInfer CUTLASS) kernel expects.
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=1.0 / layer.w2_input_scale,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass.

        Dispatch:
          * TensorRT-LLM backend: router logits are handed directly to
            ``apply_flashinfer_per_tensor_scale_fp8``.
          * Otherwise experts are selected with ``FusedMoE.select_experts``
            and run through FlashInfer CUTLASS (if selected) or
            ``fused_experts``.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )


class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4.

    Can be built from either of two config layouts (see ``from_config``):

    - the legacy ModelOpt ``hf_quant_config.json`` layout, where settings live
      under a nested ``"quantization"`` dict and excluded modules are listed
      under ``"exclude_modules"``;
    - a compressed-tensors style layout with a top-level ``"quant_algo"`` and
      excluded modules listed under ``"ignore"``.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set these unconditionally: apply_vllm_mapper(), is_layer_excluded()
        # and get_quant_method() read them for every config. Previously they
        # were only set for NVFP4-serialized checkpoints, so other checkpoints
        # failed with AttributeError instead of the intended ValueError raised
        # by the quant methods' create_weights().
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename excluded module names from HF naming to vLLM naming."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"`` and
        the algorithm names an FP4 variant, otherwise ``None``.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_optional_fields(
        src: dict[str, Any], exclude_key: str
    ) -> tuple[str | None, int, list[str]]:
        """Validate and extract the fields shared by both config layouts.

        Args:
            src: Dict holding the fields (the nested ``"quantization"`` dict or
                the top-level config).
            exclude_key: Key holding the excluded-module list
                (``"exclude_modules"`` or ``"ignore"``).

        Returns:
            ``(kv_cache_quant_algo, group_size, exclude_modules)``.
            ``kv_cache_quant_algo`` defaults to ``None``, ``group_size`` to 16
            and ``exclude_modules`` to ``[]``.

        Raises:
            ValueError: If a field has the wrong type.
        """
        kv_cache_quant_algo_raw = src.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None:
            # No KV cache quantization by default
            kv_cache_quant_algo = None
        elif isinstance(kv_cache_quant_algo_raw, str):
            kv_cache_quant_algo = kv_cache_quant_algo_raw
        else:
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_algo_raw)}"
            )

        group_size_raw = src.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        exclude_modules = src.get(exclude_key, [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        return kv_cache_quant_algo, group_size, exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from either supported layout.

        Raises:
            ValueError: On malformed fields, an unsupported ``quant_algo``, or
                (legacy layout, NVFP4) missing required fields.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            kv_cache_quant_algo, group_size, exclude_modules = (
                cls._parse_optional_fields(quant_config, "exclude_modules")
            )
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")

            # "ignore" is the key in config.json
            kv_cache_quant_algo, group_size, exclude_modules = (
                cls._parse_optional_fields(config, "ignore")
            )

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check regex pattern matching for patterns not caught by exact match
        import regex as re

        for pattern in self.exclude_modules:
            # Only wildcard / dotted patterns are turned into regexes: "." is
            # matched literally and "*" matches any run of characters.
            if "*" in pattern or "." in pattern:
                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
                if re.fullmatch(regex_str, prefix):
                    return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quant method for ``layer``, or ``None`` when unhandled.

        Excluded linear layers and vision-tower linear layers stay
        unquantized; excluded MoE layers get ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for ModelOpt checkpoints.

    Enables loading kv-cache scaling factors from FP8 checkpoints. It adds no
    state of its own; all behavior is inherited from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
        # Delegate straight to the base class with the given config.
        super(ModelOptFp8KVCacheMethod, self).__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following per-layer tensors
    (X = output size of the partition, Y = input size of the partition,
    G = ``quant_config.group_size``):

    - ``weight``: NVFP4, two 4-bit values packed per ``uint8`` byte along the
      input dimension; shape ``[X, Y // 2]``.
    - ``weight_scale``: FP8-E4M3 per-block scales; shape ``[X, Y // G]``.
    - ``weight_scale_2``: ``float32`` global weight scale, one entry per
      logical output partition (reduced to a scalar after loading).
    - ``input_scale``: ``float32`` global activation scale, one entry per
      logical output partition (reduced to a scalar after loading).

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        # Choose the FP4 GEMM backend. Without an explicit
        # VLLM_NVFP4_GEMM_BACKEND, prefer FlashInfer, then vLLM's CUTLASS
        # kernel, then Marlin.
        # NOTE(review): an explicit value not starting with "flashinfer-"
        # leaves the backend as "none" and hits the error below; confirm this
        # matches the env var's allowed choices.
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight``, ``input_scale``, ``weight_scale_2`` and
        ``weight_scale`` parameters on ``layer``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Only NVFP4-serialized checkpoints reach this point (see the check
        # above), so block scales are always FP8-E4M3.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale: one entry per logical output partition.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale: one entry per logical output partition.
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale: one scale per `group_size` input elements.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse global scales to scalars and re-layout weights for the
        selected backend."""
        # global scales: reduce per-partition entries to a single max value.
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        # Combined output scale passed to the FP4 GEMM.
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the selected NVFP4 backend."""
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
1130
1131
1132
1133
1134


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1135
    Args:
1136
1137
1138
        quant_config: NVFP4 Quant Config
    """

1139
1140
1141
1142
1143
1144
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Record the config and detect which NVFP4 MoE kernels are usable.

        Args:
            quant_config: NVFP4 quantization config.
            moe: MoE configuration, forwarded to the base class.
            layer: The module this method is attached to.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin

        # Permute-index cache reused by the TRT-LLM weight shuffling helpers.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

        if not self.allow_flashinfer:
            self.flashinfer_moe_backend = None
            return

        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )
1164

1165
    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage matching the selected kernels.

        Marlin and FlashInfer TRT-LLM get ``None``. FlashInfer CUTLASS gets the
        FlashInfer FP4 prepare/finalize. Anything else defers to the base class.
        """
        uses_trtllm = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if self.use_marlin or uses_trtllm:
            return None

        uses_flashinfer_cutlass = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        )
        if not uses_flashinfer_cutlass:
            return super().maybe_make_prepare_finalize()

        # For now, fp4 moe only works with the flashinfer dispatcher.
        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
1183

1184
1185
1186
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the NVFP4 experts implementation for the modular kernel.

        ``prepare_finalize`` and ``layer`` belong to the base-class signature
        and are not consulted here.
        """
        quant_config = self.moe_quant_config
        assert quant_config is not None
        gemm_impl = select_nvfp4_gemm_impl(
            self.moe, quant_config, allow_flashinfer=self.allow_flashinfer
        )
        logger.debug_once("Using %s", type(gemm_impl).__name__)
        return gemm_impl
1197

1198
1199
1200
1201
1202
1203
    def uses_weight_scale_2_pattern(self) -> bool:
        """Always ``True``: NVFP4 stores per-tensor weight scales as
        ``weight_scale_2`` parameters."""
        return bool(1)

1204
1205
1206
1207
1208
1209
1210
1211
1212
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 MoE parameters on ``layer``.

        Creates packed FP4 weights (``uint8``), FP8-E4M3 per-block weight
        scales, ``float32`` global weight scales (``*_weight_scale_2``) and
        ``float32`` input scales for both GEMMs.

        Args:
            layer: Module to register the parameters on.
            num_experts: Number of experts to allocate weights for.
            hidden_size: Hidden size (GEMM1 input / GEMM2 output dimension).
            intermediate_size_per_partition: Intermediate size of this
                partition (GEMM1 output per w1/w3 shard / GEMM2 input).
            params_dtype: Recorded on ``layer.params_dtype``.
            **extra_weight_attrs: Read for ``weight_loader`` and
                ``global_num_experts``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        group_size = self.quant_config.group_size

        # GEMM 1 (fused w1/w3): 2 fp4 items are packed in the input dimension.
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2: 2 fp4 items are packed in the input dimension.
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block FP8 scales: one scale per `group_size` input elements.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-expert global weight scales; w13 keeps one entry per w1/w3 shard.
        # (The former `extra_weight_attrs.update({"quant_method": ...})` calls
        # were removed: the kwargs dict is local and was never read afterwards.)
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # When the FlashInfer backend supports global scale factors, input
        # scales are allocated for all global experts instead of local ones.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        global_scale_num_experts = global_num_experts if use_global_sf else num_experts

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

1321
    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        """Prepare quantized weights for kernel (done offline with weights).

        Per expert, rows of the packed FP4 weights and their FP8 block scales
        are permuted for the FlashInfer TRT-LLM MoE kernel. The block scales
        are additionally passed through ``nvfp4_block_scale_interleave``.

        Args:
            gemm1_weights: Packed FP4 GEMM1 (w13) weights, viewable as
                ``[num_experts, 2 * intermediate_size, hidden_size // 2]``.
            gemm2_weights: Packed FP4 GEMM2 (w2) weights, viewable as
                ``[num_experts, hidden_size, intermediate_size // 2]``.
            gemm1_scales_linear_fp4_bytes: GEMM1 FP8 block scales, viewable as
                ``[num_experts, 2 * intermediate_size, hidden_size // 16]``.
            gemm2_scales_linear_fp4_bytes: GEMM2 FP8 block scales, viewable as
                ``[num_experts, hidden_size, intermediate_size // 16]``.
            hidden_size: Hidden size.
            intermediate_size: Intermediate size (per w1/w3 shard).
            num_experts: Number of experts in the tensors.

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)``.
            Weights are ``uint8``; scales are ``float8_e4m3fn`` reshaped to the
            input scale shapes above.
        """
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w3_w1_permute_indices,
            get_w2_permute_indices_with_cache,
        )

        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        # Number of FP4 elements covered by one FP8 block scale.
        sf_vec_size = 16

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // sf_vec_size
        )  # fp8 scaling factors

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_fp4_shuffled.append(
                gemm1_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm1_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

            permute_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_fp4_shuffled.append(
                gemm2_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm2_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_vec_size)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // sf_vec_size)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )
1444

1445
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process the loaded NVFP4 MoE tensors for the selected backend.

        Steps, in order:

        1. When FlashInfer is allowed, reorder the fused ``w13`` weight and
           its block scale from w1/w3 to w3/w1 order along dim ``-2`` (via
           ``reorder_w1w3_to_w3w1``).
        2. Collapse ``w13_weight_scale_2`` to its first column (the w1
           global scale), warning when the w1 and w3 columns differ.
        3. Reduce the activation (input) global scales and derive
           ``g1_alphas`` / ``g2_alphas`` (input scale * weight global scale)
           plus the reciprocal ``*_input_scale_quant`` tensors used when
           quantizing activations.
        4. Convert to the backend layout:

           * FlashInfer TRT-LLM: shuffle weights and block scales into
             ``gemm{1,2}_{weights,scales}_fp4_shuffled``, add ``g1_scale_c``
             and delete the original ``w13``/``w2`` weight tensors.
           * Marlin: hand the layer to ``prepare_moe_fp4_layer_for_marlin``
             and drop the alpha / quant-scale tensors that path never reads.
           * Otherwise (CUTLASS or non-FlashInfer): swizzle both block
             scales with ``swizzle_blockscale``.

        Args:
            layer: the MoE layer; its parameters are replaced in place.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            # FlashInfer kernels consume the fused gate/up projection in
            # w3/w1 order, so swap the halves of both weight and block scale.
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2: only a single global
        # scale per expert is kept (column 0, i.e. w1's), so warn if the
        # w1 and w3 values disagree.
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts: take one scalar max and broadcast it.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            # Reduce dim 1 (presumably the w1/w3 pair) to one scale per expert.
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)

        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale

        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing; the Marlin kernel call reads the weight
            # global scales directly, so the alphas / quant scales are unused.
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 quant config consumed by the CUTLASS-style paths.

        Returns ``None`` for the Marlin and FlashInfer TRT-LLM paths, whose
        kernel calls in ``apply`` read the layer tensors directly instead of
        a ``FusedMoEQuantConfig``. The TRT-LLM check is gated on
        ``allow_flashinfer`` exactly as in ``apply`` and
        ``process_weights_after_loading``. This keeps the fallback CUTLASS
        paths, which assert a non-None config, from ever receiving ``None``.

        Args:
            layer: the MoE layer after ``process_weights_after_loading``.

        Returns:
            The config built by ``nvfp4_moe_quant_config`` from the layer's
            block scales, alphas and activation quant scales, or ``None``.
        """
        uses_trtllm = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if self.use_marlin or uses_trtllm:
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE computation for ``x``.

        Backend dispatch is checked in this order:

        * FlashInfer TRT-LLM (``allow_flashinfer`` and backend
          ``TENSORRT_LLM``). Activations are FP4-quantized here and routing
          is done inside ``trtllm_fp4_block_scale_moe``.
        * Otherwise, experts are chosen with ``FusedMoE.select_experts``.
          The work then goes to Marlin (``use_marlin``), the FlashInfer
          CUTLASS kernel (backend ``CUTLASS``), or ``cutlass_moe_fp4``.

        The EPLB-related arguments (``expert_load_view``,
        ``logical_to_physical_map`` and ``logical_replica_count``) are
        unused because EPLB is rejected up front.

        Returns:
            The output of the selected backend kernel. For TRT-LLM this is
            element 0 of the kernel's result.

        Raises:
            NotImplementedError: if ``enable_eplb`` is True.
            AssertionError: if ``activation`` is not ``"silu"``.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Routing happens inside the TRT-LLM kernel, so select_experts
            # below is bypassed entirely on this path.
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            # Quantize activations with the reciprocal global scale built in
            # process_weights_after_loading; block scales are requested in the
            # non-swizzled (linear) layout.
            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            # NOTE(review): every model that does not use Llama4's routing
            # function is routed with the DeepSeekV3 method, and
            # ``renormalize``, ``scoring_func``, ``use_grouped_topk`` and
            # ``routed_scaling_factor`` are not forwarded to the kernel.
            # Confirm this holds for every model served on this path.
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4

            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)

            # Llama4 routing takes the logits as-is; the other routing
            # methods are given float32 logits.
            routing_logits = (
                router_logits
                if use_llama4_routing
                else router_logits.to(torch.float32)
            )

            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=routing_logits,
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        # All remaining backends take precomputed top-k routing.
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.use_marlin:
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4,
            )

            assert self.moe_quant_config is not None

            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                # NOTE(review): ``* 2`` presumably undoes two-FP4-per-byte
                # packing of w2's last dim — confirm.
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )