modelopt.py 68.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.fused_moe.config import (
16
17
    FusedMoEConfig,
    FusedMoEQuantConfig,
18
    RoutingMethodType,
19
20
21
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
22
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
23
from vllm.model_executor.layers.fused_moe.layer import (
24
25
26
27
28
29
30
31
32
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
33
from vllm.model_executor.layers.quantization import QuantizationMethods
34
from vllm.model_executor.layers.quantization.base_config import (
35
36
37
    QuantizationConfig,
    QuantizeMethodBase,
)
38
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
39
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
40
41
42
43
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
44
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
45
46
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
47
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
48
49
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
50
    is_flashinfer_supporting_global_sf,
51
52
53
54
55
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
56
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
57
58
59
60
61
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
62
from vllm.model_executor.layers.quantization.utils.quant_utils import (
63
64
65
66
67
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
68
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
69
70
71
72
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
73
from vllm.scalar_type import scalar_types
74
75
76
77
78
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
79

80
81
82
if TYPE_CHECKING:
    # Only needed for the string annotation on ``apply_vllm_mapper``; kept off
    # the runtime import path (presumably to avoid an import cycle).
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

# ``quant_algo`` values accepted by ``ModelOptFp8Config.from_config``; anything
# else is rejected with a ValueError there.
QUANT_ALGOS: list[str] = ["FP8", "NVFP4"]
# ``kv_cache_quant_algo`` values recognized for ModelOpt checkpoints.
KV_CACHE_QUANT_ALGOS: list[str] = ["FP8"]


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Built from either the legacy ``hf_quant_config.json`` layout
    (``{"quantization": {"quant_algo": ...}}``) or the compressed-tensors style
    layout (``{"quant_algo": ..., "quant_method": "modelopt"}``). It selects the
    per-layer quantization method for linear, attention and fused-MoE layers.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: str | None = None,
        exclude_modules: list[str] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint stores FP8
                weights together with their scales.
            kv_cache_quant_method: The checkpoint's ``kv_cache_quant_algo``,
                if any.
            exclude_modules: Module names/patterns to leave unquantized.
                ``None`` is normalized to an empty list.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` from HF names to vLLM module names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``hf_quant_cfg`` declares
        ``quant_method == "modelopt"`` (case-insensitive) and a ``quant_algo``
        string containing ``"FP8"``. Returns ``None`` otherwise, including for
        malformed (non-string / null) ``quant_method`` or ``quant_algo`` values.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. A null or non-string
        # value means "not ModelOpt" rather than an AttributeError on .lower().
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Legacy ModelOpt layout: {"quantization": {"quant_algo": ...}}
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Compressed-tensors style layout with a top-level quant_algo.
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Same type check for both layouts: a null quant_algo must not raise.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from a parsed quantization config dict.

        Raises:
            ValueError: If the ``quantization`` entry is not a dict, if it
                lacks ``quant_algo``, or if the algorithm is not in
                ``QUANT_ALGOS``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer`` (``None`` if unhandled).

        Excluded and vision-tower/vision-model linear layers stay unquantized.
        """
        from vllm.attention.layer import (  # Avoid circular import
            Attention,
            MLAAttention,
        )

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, (Attention, MLAAttention)):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation (input) scale. Dynamic scales may be supported in the future.

    Limitations:
    1. Only per-tensor quantization is supported (torch._scaled_mm support).
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations use a static, per-tensor input scale.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` on ``layer``.

        For FP8-serialized checkpoints, also register ``weight_scale`` and
        ``input_scale``, each with one float32 slot per output partition.
        """
        del input_size, output_size  # the unsharded sizes are not needed
        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)
        num_partitions = len(output_partition_sizes)
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    out_features,
                    input_size_per_partition,
                    dtype=torch.float8_e4m3fn if fp8_serialized else params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # Weight scale first, then input scale. Each is pre-filled with the
        # float32 minimum until the checkpoint loader overwrites it.
        for scale_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=loader,
            )
            scale_param[:] = torch.finfo(torch.float32).min
            layer.register_parameter(scale_name, scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-partition scales into single per-tensor scales.

        When the partitions' weight scales differ, the weight is requantized
        against the largest one. The weight is then stored transposed.
        """
        w_scale = layer.weight_scale
        if (w_scale == w_scale[0]).all():
            weight, max_w_scale = layer.weight, w_scale.max()
        else:
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, w_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the FP8 linear layer for ``x`` using the layer's scales."""
        op_kwargs = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
        return self.fp8_linear.apply(**op_kwargs)


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    static per-tensor activation scales. When ``VLLM_USE_FLASHINFER_MOE_FP8`` is
    set and FlashInfer MoE is available, experts run on a FlashInfer backend
    (TensorRT-LLM or CUTLASS). Otherwise they run on ``fused_experts``.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The MoE layer; its ``moe_config`` is passed to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            if (
                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and not self.moe.is_act_and_mul
            ):
                # The TensorRT-LLM path is not used for non-gated MoE;
                # switch to the CUTLASS backend instead.
                logger.info_once(
                    "Non-gated MoE is not supported for min-latency mode, "
                    "falling back to high-throughput mode"
                )
                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage matching the selected backend.

        ``None`` for TensorRT-LLM, a FlashInfer CUTLASS prepare/finalize for
        CUTLASS, and the base-class choice otherwise.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the CUTLASS FP8 experts implementation for this MoE config."""
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        ``w13_weight`` has shape (num_experts, w13_up_dim, hidden_size).
        ``w13_up_dim`` is 2 * intermediate_size_per_partition for gated
        (act-and-mul) MoE and intermediate_size_per_partition otherwise.
        ``w2_weight`` has shape
        (num_experts, hidden_size, intermediate_size_per_partition).
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales.
            # NOTE(review): extra_weight_attrs is not read again in this
            # method, so this update does not affect the parameters above.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Steps:
        - Fold the per-shard w1/w3 weight scales into one scale per expert,
          requantizing the affected shards.
        - Reduce the input scales to a single maximum.
        - Apply FlashInfer weight layout changes when a FlashInfer backend is
          selected.
        - Register the MoE scaling factors.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale; the result
                        # is written back into the same slice in place.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            if self.moe.is_act_and_mul:
                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        register_moe_scaling_factors(layer)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config for modular kernels.

        Returns ``None`` for the TensorRT-LLM backend. Reads
        ``output1_scales_gate_scalar``, ``output2_scales_scalar`` and
        ``w2_input_scale_inv``, which are presumably populated by
        ``register_moe_scaling_factors`` (TODO confirm).
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
            w2_scale=layer.w2_weight_scale,
            g2_alphas=layer.output2_scales_scalar.squeeze(),
            a1_scale=layer.w13_input_scale,
            a1_gscale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            a2_gscale=layer.w2_input_scale_inv,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass.

        The TensorRT-LLM backend routes internally and supports only ``silu``
        without renormalization. The other paths select experts first, then
        dispatch to FlashInfer CUTLASS or ``fused_experts``.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )


class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4 (NVFP4).

    Two on-disk layouts are accepted by ``from_config``:

    - the legacy ``hf_quant_config.json`` layout, where settings live under a
      ``"quantization"`` dict and excluded modules are listed under
      ``"exclude_modules"``;
    - the flat ``config.json`` layout
      (``{"quant_method": "modelopt", "quant_algo": ..., "ignore": [...]}``).
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        # Always store these settings: is_layer_excluded / get_quant_method /
        # apply_vllm_mapper read them unconditionally, so leaving them unset
        # for a non-serialized config would surface later as AttributeError.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``exclude_modules`` entries through the HF->vLLM mapper."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the quant algo names an FP4 scheme, otherwise
        ``None``. Malformed (non-string) fields are treated as "no match"
        instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; only proceed if it is
        # explicitly "modelopt".
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Legacy ModelOpt layout: {"quantization": {"quant_algo": ...}}
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Flat (compressed-tensors style) layout with a top-level
            # quant_algo field.
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_kv_cache_quant_algo(section: dict[str, Any]) -> str | None:
        """Return ``section["kv_cache_quant_algo"]`` (``None`` if absent).

        Raises:
            ValueError: if the value is present but not a string.
        """
        raw = section.get("kv_cache_quant_algo")
        if raw is None or isinstance(raw, str):
            # No KV cache quantization by default.
            return raw
        raise ValueError(f"kv_cache_quant_algo must be a string, got {type(raw)}")

    @staticmethod
    def _parse_group_size(section: dict[str, Any]) -> int:
        """Return ``section["group_size"]`` as an int (16 if absent).

        Raises:
            ValueError: if the value cannot be converted with ``int()``.
        """
        raw = section.get("group_size")
        if raw is None:
            return 16  # Default value
        if isinstance(raw, int):
            return raw
        try:
            return int(raw)
        except (ValueError, TypeError):
            raise ValueError(
                f"group_size must be an integer, got {type(raw)}"
            ) from None

    @staticmethod
    def _parse_exclude_modules(section: dict[str, Any], key: str) -> list[str]:
        """Return the list stored under ``key`` (empty list if absent).

        Raises:
            ValueError: if the value is not a list.
        """
        exclude_modules = section.get(key, [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )
        return exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from either supported on-disk layout.

        Raises:
            ValueError: on a malformed section, a missing/unsupported
                ``quant_algo``, or (legacy NVFP4 layout) missing required
                fields.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            kv_cache_quant_algo = cls._parse_kv_cache_quant_algo(quant_config)
            group_size = cls._parse_group_size(quant_config)
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = cls._parse_exclude_modules(
                quant_config, "exclude_modules"
            )
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_algo = cls._parse_kv_cache_quant_algo(config)
            group_size = cls._parse_group_size(config)
            # "ignore" is the key in config.json
            exclude_modules = cls._parse_exclude_modules(config, "ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4 in the legacy layout these fields must be spelled out
        # explicitly rather than falling back to the defaults above.
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            quant_config = config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching,
        where ``*`` in an exclude entry matches any run of characters and
        every other character is taken literally.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check glob-style pattern matching for entries not caught above.
        import regex as re

        for pattern in self.exclude_modules:
            # Entries with neither "*" nor "." are left to the exact match.
            if "*" not in pattern and "." not in pattern:
                continue
            # Escape all regex metacharacters (not only "."), then turn the
            # escaped glob star back into ".*".
            regex_str = re.escape(pattern).replace(r"\*", ".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or ``None`` to skip it."""
        from vllm.attention.layer import (  # Avoid circular import
            Attention,
            MLAAttention,
        )

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, (Attention, MLAAttention)):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    KV-cache method for ModelOpt configs.

    Supports loading kv-cache scaling factors from FP8 checkpoints. This
    subclass adds no logic of its own; it only forwards the config to
    ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
        super(ModelOptFp8KVCacheMethod, self).__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following per-layer
    parameters (shapes as allocated by ``create_weights``, with
    ``N = output_size_per_partition`` and ``K = input_size_per_partition``):

    - ``weight``: packed NVFP4 values stored as ``torch.uint8`` (two FP4
      values per byte), shape ``[N, K // 2]``.
    - ``weight_scale``: FP8-E4M3 per-block scales, shape
      ``[N, K // group_size]``.
    - ``weight_scale_2``: ``torch.float32`` global weight scale, one entry per
      fused output partition; reduced to a scalar (max) after loading.
    - ``input_scale``: ``torch.float32`` activation scale, one entry per fused
      output partition; reduced to a scalar (max) after loading.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        # Choose the GEMM backend: auto-detect when VLLM_NVFP4_GEMM_BACKEND is
        # unset, otherwise honour the requested one. Any other value leaves
        # the backend as "none", which is rejected below.
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 parameters on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or the input size is not a
                multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # The checkpoint is guaranteed NVFP4-serialized (checked above), so the
        # per-block weight scales are always stored as FP8-E4M3.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight: 2 fp4 items are packed into each uint8 along the input dim.
        weight = ModelWeightParameter(
            data=torch.empty(
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input scale: one entry per fused output partition.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global weight scale: one entry per fused output partition.
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per-block weight scale: one value per `group_size` input elements.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse global scales and re-layout weights for the chosen backend."""
        # Global scales: reduce the per-partition values to a single scalar.
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the selected NVFP4 backend.

        All leading dimensions of ``x`` are preserved in the output; only the
        last dimension changes to ``layer.weight.shape[0]``.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        # Keep every leading (token/batch) dim of x rather than only dim 0, so
        # inputs with more than two dims reshape correctly after the 2-D GEMM.
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE method for ModelOpt NVFP4 quantization.

    Args:
        quant_config: NVFP4 Quant Config
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Store the config and probe which NVFP4 MoE kernels are usable.

        Args:
            quant_config: The ModelOpt NVFP4 quantization config.
            moe: MoE layer configuration, forwarded to the base class.
            layer: The layer this method is attached to.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin

        # Permutation indices keyed by tensor shape, reused while shuffling
        # expert weights in prepare_static_weights_for_trtllm_fp4_moe.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

        # Only resolved when FlashInfer is allowed; stays None otherwise.
        self.flashinfer_moe_backend = None
        if not self.allow_flashinfer:
            return
        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the active MoE backend.

        Marlin and FlashInfer TRT-LLM need none (``None``); FlashInfer CUTLASS
        uses the FlashInfer FP4 dispatcher; everything else defers to the base
        class.
        """
        if self.use_marlin:
            return None

        if self.allow_flashinfer:
            backend = self.flashinfer_moe_backend
            if backend == FlashinferMoeBackend.TENSORRT_LLM:
                return None
            if backend == FlashinferMoeBackend.CUTLASS:
                # For now, fp4 moe only works with the flashinfer dispatcher.
                pf = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
                logger.debug_once("%s", pf.__class__.__name__)
                return pf

        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the NVFP4 experts implementation for the modular kernel.

        ``prepare_finalize`` and ``layer`` belong to the base-class interface
        and are not consulted; the choice depends on the MoE config, the MoE
        quant config and whether FlashInfer is allowed.
        """
        quant_cfg = self.moe_quant_config
        assert quant_cfg is not None
        chosen = select_nvfp4_gemm_impl(
            self.moe,
            quant_cfg,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", type(chosen).__name__)
        return chosen

    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales are named ``*weight_scale_2``.

        FP4 variants always use this naming, so the answer is constant.
        """
        uses_scale_2_naming = True
        return uses_scale_2_naming

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the NVFP4 expert parameters on ``layer``.

        Registers, in order: packed FP4 weights (``w13_weight``,
        ``w2_weight``, uint8), FP8-E4M3 per-block scales (``*_weight_scale``),
        float32 per-expert global weight scales (``*_weight_scale_2``) and
        float32 input scales (``*_input_scale``). The input scales are sized by
        ``global_num_experts`` when the FlashInfer backend supports global
        scale factors, otherwise by ``num_experts``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config

        loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        block = self.quant_config.group_size
        w13_rows = 2 * intermediate_size_per_partition

        def _expert_param(rows: int, cols: int, dtype: torch.dtype):
            # [num_experts, rows, cols]; input/output dims as in the original
            # layout (NOTE(review): FusedMoE's loader presumably ignores them).
            return ModelWeightParameter(
                data=torch.empty(num_experts, rows, cols, dtype=dtype),
                input_dim=1,
                output_dim=2,
                weight_loader=loader,
            )

        def _scale_param(*shape: int):
            return PerTensorScaleParameter(
                data=torch.empty(*shape, dtype=torch.float32),
                weight_loader=loader,
            )

        # GEMM 1 / GEMM 2 weights: 2 fp4 items packed per byte along the
        # input (last) dimension.
        layer.register_parameter(
            "w13_weight", _expert_param(w13_rows, hidden_size // 2, torch.uint8)
        )
        layer.register_parameter(
            "w2_weight",
            _expert_param(
                hidden_size, intermediate_size_per_partition // 2, torch.uint8
            ),
        )

        # Per-block scales: one FP8-E4M3 value per `group_size` input elements.
        layer.register_parameter(
            "w13_weight_scale",
            _expert_param(w13_rows, hidden_size // block, torch.float8_e4m3fn),
        )
        layer.register_parameter(
            "w2_weight_scale",
            _expert_param(
                hidden_size,
                intermediate_size_per_partition // block,
                torch.float8_e4m3fn,
            ),
        )

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        # Per-expert global weight scales (w13 holds one per half).
        layer.register_parameter("w13_weight_scale_2", _scale_param(num_experts, 2))
        layer.register_parameter("w2_weight_scale_2", _scale_param(num_experts))

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        scale_rows = global_num_experts if use_global_sf else num_experts

        layer.register_parameter("w13_input_scale", _scale_param(scale_rows, 2))
        layer.register_parameter("w2_input_scale", _scale_param(scale_rows))

    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        """Prepare quantized weights for kernel (done offline with weights).

        For every expert, the packed FP4 weights and FP8-E4M3 block scales of
        both GEMMs are row-permuted for the FlashInfer TRT-LLM MoE kernel:
        W3/W1 rows are reordered for the fused gated activation and rows are
        shuffled for the transposed MMA output. Block scales are additionally
        passed through ``nvfp4_block_scale_interleave``.

        Args:
            gemm1_weights: packed FP4 w13 weights, reshaped here to
                ``[num_experts, 2 * intermediate_size, hidden_size // 2]``.
            gemm2_weights: packed FP4 w2 weights, reshaped here to
                ``[num_experts, hidden_size, intermediate_size // 2]``.
            gemm1_scales_linear_fp4_bytes: w13 block scales, reshaped here to
                ``[num_experts, 2 * intermediate_size, hidden_size // 16]``.
            gemm2_scales_linear_fp4_bytes: w2 block scales, reshaped here to
                ``[num_experts, hidden_size, intermediate_size // 16]``.
            hidden_size: hidden size used in the reshapes above.
            intermediate_size: intermediate size used in the reshapes above.
            num_experts: number of experts to process.

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)``:
            stacked per-expert tensors, weights as ``torch.uint8`` and scales as
            ``torch.float8_e4m3fn``, in the shapes listed above.
        """
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w3_w1_permute_indices,
            get_w2_permute_indices_with_cache,
        )

        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        # NVFP4 carries one FP8 scale factor per 16 FP4 elements.
        sf_vec_size = 16

        # Reinterpret as 1-byte fp8 only to reshape per expert.
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // sf_vec_size
        )  # fp8 scaling factors

        def _permute_rows(mat, index_fn, **index_kwargs):
            # Compute (cached) permutation indices for `mat` and gather its
            # rows from the uint8 view, returning a contiguous tensor.
            mat_u8 = mat.view(torch.uint8)
            indices = index_fn(
                self._cache_permute_indices,
                mat_u8,
                epilogue_tile_m,
                **index_kwargs,
            )
            return mat_u8[indices.to(mat.device)].contiguous()

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            gemm1_weights_fp4_shuffled.append(
                _permute_rows(
                    gemm1_weights_fp4[i], _maybe_get_cached_w3_w1_permute_indices
                )
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    _permute_rows(
                        gemm1_scales_linear_fp4[i],
                        _maybe_get_cached_w3_w1_permute_indices,
                        num_elts_per_sf=sf_vec_size,
                    )
                )
            )
            gemm2_weights_fp4_shuffled.append(
                _permute_rows(gemm2_weights_fp4[i], get_w2_permute_indices_with_cache)
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    _permute_rows(
                        gemm2_scales_linear_fp4[i],
                        get_w2_permute_indices_with_cache,
                        num_elts_per_sf=sf_vec_size,
                    )
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_vec_size)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // sf_vec_size)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded NVFP4 MoE tensors for the selected kernel path.

        The work happens in four stages:

        1. For the FlashInfer CUTLASS backend, the fused w13 weight and its
           block scale are reordered with ``reorder_w1w3_to_w3w1`` (dim -2).
        2. ``w13_weight_scale_2`` is reduced to its first (W1) column; a
           warning is logged when the W1 and W3 columns are not ``allclose``.
        3. ``g1_alphas`` / ``g2_alphas`` (activation scale times weight global
           scale) and the reciprocal activation scales
           ``w13_input_scale_quant`` / ``w2_input_scale_quant`` are created.
           When FlashInfer is allowed and ``is_flashinfer_supporting_global_sf``
           is true for the backend, a single max activation scale is expanded
           to ``layer.num_experts`` entries.
        4. Backend-specific finalization:
           * TRT-LLM: shuffled weights/scales and ``g1_scale_c`` are added and
             the unshuffled weight/scale tensors are deleted.
           * Marlin: ``prepare_moe_fp4_layer_for_marlin`` repacks the layer
             and the alpha / reciprocal-scale tensors are deleted.
           * Anything else: block scales are swizzled with
             ``swizzle_blockscale``.

        Args:
            layer: The fused-MoE module whose parameters are replaced in place.
        """
        # ---- Stage 1: GEMM1 weight / block-scale layout --------------------
        w13_data = layer.w13_weight.data
        w13_blockscale = layer.w13_weight_scale.data
        if self.allow_flashinfer and (
            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            w13_data, w13_blockscale = reorder_w1w3_to_w3w1(
                w13_data, w13_blockscale, dim=-2
            )
        layer.w13_weight = Parameter(w13_data, requires_grad=False)
        layer.w13_weight_scale = Parameter(w13_blockscale, requires_grad=False)

        # ---- Stage 2: one global weight scale for the fused W1/W3 GEMM -----
        w1_global_scale = layer.w13_weight_scale_2[:, 0]
        w3_global_scale = layer.w13_weight_scale_2[:, 1]
        if not torch.allclose(w1_global_scale, w3_global_scale):
            # Only the W1 column is kept below, so a mismatch is lossy.
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        layer.w13_weight_scale_2 = Parameter(w1_global_scale, requires_grad=False)

        # ---- Stage 3: activation scales and alphas --------------------------
        global_sf = False
        if self.allow_flashinfer:
            global_sf = is_flashinfer_supporting_global_sf(
                self.flashinfer_moe_backend
            )
        if global_sf:
            # Backends provided by FlashInfer share one activation global
            # scale across all experts: take the max and replicate it.
            a1_scale, a2_scale = (
                scale.max().to(torch.float32).expand(layer.num_experts)
                for scale in (layer.w13_input_scale, layer.w2_input_scale)
            )
        else:
            # Reduce the w13 input scales over dim 1 (presumably the W1/W3
            # pair — confirm against create_weights); w2 is used as-is.
            a1_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
            a2_scale = layer.w2_input_scale

        layer.g1_alphas = Parameter(
            (a1_scale * w1_global_scale).to(torch.float32),
            requires_grad=False,
        )
        # Stored as a reciprocal because it is used to quantize activations.
        layer.w13_input_scale_quant = Parameter(
            (1 / a1_scale).to(torch.float32), requires_grad=False
        )
        layer.g2_alphas = Parameter(
            (a2_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )
        layer.w2_input_scale_quant = Parameter(
            (1 / a2_scale).to(torch.float32), requires_grad=False
        )

        # ---- Stage 4: backend-specific finalization -------------------------
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            w13_param = layer.w13_weight
            w2_param = layer.w2_weight
            g1_w, g1_s, g2_w, g2_s = self.prepare_static_weights_for_trtllm_fp4_moe(
                w13_param,
                w2_param,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                w2_param.size(-2),  # hidden_size
                w13_param.size(-2) // 2,  # intermediate_size
                w13_param.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            # Registration order is kept stable: weights first, then scales.
            for attr_name, tensor in (
                ("gemm1_weights_fp4_shuffled", g1_w),
                ("gemm2_weights_fp4_shuffled", g2_w),
                ("gemm1_scales_fp4_shuffled", g1_s),
                ("gemm2_scales_fp4_shuffled", g2_s),
            ):
                setattr(layer, attr_name, Parameter(tensor, requires_grad=False))

            # GEMM1 output scale passed to the TRT-LLM kernel; it folds the
            # GEMM2 activation quantization scale into g1_alphas.
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # The TRT-LLM path only reads the shuffled copies created above.
            for stale in (
                "w2_weight",
                "w2_weight_scale",
                "w13_weight",
                "w13_weight_scale",
            ):
                delattr(layer, stale)
        elif self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            # The Marlin path in ``apply`` does not read these tensors.
            for unused in (
                "g1_alphas",
                "g2_alphas",
                "w13_input_scale_quant",
                "w2_input_scale_quant",
            ):
                delattr(layer, unused)
        else:
            # Remaining paths (FlashInfer CUTLASS/CuteDSL or no FlashInfer)
            # consume swizzled block scales.
            layer.w13_weight_scale = Parameter(
                swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
            )
            layer.w2_weight_scale = Parameter(
                swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
1601

1602
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 quant config consumed by the modular-kernel paths.

        Returns ``None`` for the Marlin and FlashInfer TRT-LLM paths. ``apply``
        calls those kernels directly without ``self.moe_quant_config``. Also,
        ``process_weights_after_loading`` deletes tensors this config would
        reference on those paths: the alphas and reciprocal input scales for
        Marlin, and the linear block scales for TRT-LLM.

        Args:
            layer: The MoE layer after ``process_weights_after_loading``.

        Returns:
            An ``nvfp4_moe_quant_config`` built from the layer's block scales,
            ``g1_alphas`` / ``g2_alphas`` and reciprocal input scales, or
            ``None`` as described above.
        """
        # The TRT-LLM check is gated on ``allow_flashinfer`` exactly like the
        # matching conditions in ``process_weights_after_loading`` and
        # ``apply``. Without the gate, a TENSORRT_LLM backend value with
        # FlashInfer disallowed would return None here while those methods
        # take the CUTLASS path, whose ``apply`` branch asserts
        # ``self.moe_quant_config is not None`` (presumably filled from this
        # method — confirm against the base class).
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

1620
1621
1622
1623
1624
1625
1626
1627
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass on ``x``.

        Dispatch:

        * FlashInfer TRT-LLM backend: ``x`` is FP4-quantized with
          ``flashinfer.fp4_quantize``. It is passed, together with the router
          logits, to ``flashinfer.fused_moe.trtllm_fp4_block_scale_moe``, which
          performs routing itself (``FusedMoE.select_experts`` is not called).
        * Otherwise, experts are chosen with ``FusedMoE.select_experts``. The
          work then goes to ``fused_marlin_moe`` (Marlin), to the FlashInfer
          CUTLASS / CuteDSL FP4 kernels, or to vLLM's ``cutlass_moe_fp4``.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are not read here; ``enable_eplb`` raises.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
            AssertionError: if ``activation`` is not ``"silu"``.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            e4m3 = torch.float8_e4m3fn
            hs_fp4, hs_scale = flashinfer.fp4_quantize(
                x,
                layer.w13_input_scale_quant,
                is_sf_swizzled_layout=False,
            )

            # Llama4's Python router is replaced by the kernel's Llama4 mode.
            routing_method_type = layer.routing_method_type
            if custom_routing_function is Llama4MoE.custom_routing_function:
                routing_method_type = RoutingMethodType.Llama4
            if routing_method_type == RoutingMethodType.DeepSeekV3:
                router_logits = router_logits.to(torch.float32)
            routing_bias = (
                None
                if e_score_correction_bias is None
                else e_score_correction_bias.to(torch.bfloat16)
            )

            outputs = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=router_logits,
                routing_bias=routing_bias,
                hidden_states=hs_fp4,
                hidden_states_scale=hs_scale.view(e4m3).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(e4m3),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(e4m3),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group,
                topk_group=topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                # NOTE(review): the ``routed_scaling_factor`` argument is not
                # forwarded here, unlike the select_experts path below —
                # confirm this is intended for every routing method.
                routed_scaling_factor=1.0,
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
                do_finalize=True,
            )
            # Only the first returned element is used (do_finalize=True).
            return outputs[0]

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.use_marlin:
            # Marlin takes the per-expert global weight scales directly.
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,  # presumably the w1 bias slot — confirm signature
                None,  # presumably the w2 bias slot — confirm signature
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        if self.allow_flashinfer:
            assert self.flashinfer_moe_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4 as moe_fp4_kernel,
                )
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4 as moe_fp4_kernel,
                )

            assert self.moe_quant_config is not None
            return moe_fp4_kernel(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # No FlashInfer and no modular kernel: use vLLM's cutlass_moe_fp4,
        # intended for the TP case only (no EP).
        from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

        assert self.moe_quant_config is not None
        return cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=self.moe_quant_config,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            # TODO: derive from arguments
            m=x.shape[0],
            # presumably packed FP4 (two values per byte) — hence the * 2
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
        )