"tests/kernels/attention/test_encoder_decoder_attn.py" did not exist on "f256ebe4df6757d76f1f1642d7e110268a2f8190"
modelopt.py 68.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.fused_moe.config import (
16
17
18
19
20
    FusedMoEConfig,
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
22
23
    is_valid_flashinfer_cutlass_fused_moe,
)
24
from vllm.model_executor.layers.fused_moe.layer import (
25
26
27
28
29
30
31
32
33
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
34
from vllm.model_executor.layers.quantization import QuantizationMethods
35
from vllm.model_executor.layers.quantization.base_config import (
36
37
38
    QuantizationConfig,
    QuantizeMethodBase,
)
39
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
40
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
41
42
43
44
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
45
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
46
47
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
48
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
49
50
51
52
53
54
55
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
56
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
57
58
59
60
61
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
62
from vllm.model_executor.layers.quantization.utils.quant_utils import (
63
64
65
66
67
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
68
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
69
70
71
72
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
73
from vllm.scalar_type import scalar_types
74
from vllm.utils import next_power_of_2
75
76
77
78
79
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
80

81
82
83
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

84
85
logger = init_logger(__name__)

# Accepted values of the checkpoint's ``quant_algo`` field; checked by
# ``ModelOptFp8Config.from_config`` (and used in its error message).
QUANT_ALGOS: list[str] = ["FP8", "NVFP4"]
# KV-cache quantization algorithms; presumably validated against the
# checkpoint's ``kv_cache_quant_algo`` elsewhere in this module — confirm.
KV_CACHE_QUANT_ALGOS: list[str] = ["FP8"]
88
89
90
91
92
93
94
95


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint stores FP8
            weights together with their scales.
        kv_cache_quant_method: The ``kv_cache_quant_algo`` value read from
            the checkpoint's quantization config, if any.
        exclude_modules: Module names (or name fragments) that must stay
            unquantized. ``None`` is normalized to an empty list.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: str | None = None,
        exclude_modules: list[str] | None = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this config supports."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required."""
        # NOTE(review): presumably compute capability 8.9 encoded as
        # major * 10 + minor — confirm against QuantizationConfig.
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the checkpoint files that hold the quantization config."""
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``exclude_modules`` from HF names to vLLM module names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` only when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` string contains ``"FP8"``.
        Otherwise returns ``None``. Malformed fields, such as JSON ``null``
        or non-string values, are treated as "no match" instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. Guard against a null or
        # non-string value, which would otherwise crash on ``.lower()``.
        quant_method = hf_quant_cfg.get("quant_method", "")
        # Only proceed if the method is explicitly "modelopt"
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = hf_quant_cfg["quantization"]
            quant_algo = (
                quant_config.get("quant_algo", "")
                if isinstance(quant_config, dict)
                else ""
            )
        else:
            # Compressed-tensors style config with a top-level quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # One shared type check for both layouts: a null ``quant_algo`` would
        # otherwise raise TypeError on the ``in`` test.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build a config from either supported checkpoint layout.

        Raises:
            ValueError: if ``quantization`` is not a dict, if ``quant_algo``
                is missing from it, or if the algorithm is not in
                ``QUANT_ALGOS``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Choose the quantize method for ``layer``.

        Linear layers that are excluded, or whose prefix names a vision tower
        or vision model, stay unquantized. Other linear layers use FP8
        linear. Attention layers get the FP8 KV-cache method and FusedMoE
        layers the FP8 MoE method. Any other layer gets ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations use a static (checkpoint-provided) per-tensor scale.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, its scales on ``layer``.

        ``weight`` has shape (sum(output_partition_sizes),
        input_size_per_partition). For FP8-serialized checkpoints it is
        ``float8_e4m3fn``; otherwise it uses ``params_dtype``. In the FP8 case,
        one ``weight_scale`` and one ``input_scale`` entry is allocated per
        output partition. ``process_weights_after_loading`` merges them.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE: one entry per output partition. It is initialized
            # to the float32 minimum, presumably so that an entry the loader
            # never writes cannot win the ``.max()`` reduction in
            # process_weights_after_loading.
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE: same layout and initialization as the weight scale.
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-partition scales into single per-tensor scales.

        If the per-partition weight scales are not all equal, the weight is
        passed through ``requantize_with_max_scale``. That helper returns the
        shared max scale and the requantized weight. The weight is then
        stored transposed, as (input_size_per_partition,
        output_size_per_partition). The input scale becomes the max over
        partitions.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op with the layer's merged per-tensor scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
332
333


334
335
336
337
338
339
340
341
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The MoE layer this method serves; its ``moe_config`` is
            forwarded to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # FlashInfer is opt-in: it requires VLLM_USE_FLASHINFER_MOE_FP8 and
        # ``has_flashinfer_moe()`` must report support.
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object matching the FlashInfer backend."""
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the CUTLASS FP8 experts implementation for this MoE config."""
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights and, for FP8 checkpoints, their scales.

        Shapes:
            w13_weight: (num_experts, 2 * intermediate_size_per_partition,
                hidden_size)
            w2_weight: (num_experts, hidden_size,
                intermediate_size_per_partition)
            w13_weight_scale: (num_experts, 2), one scale each for w1 and w3
            w2_weight_scale, w13_input_scale, w2_input_scale: (num_experts,)
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            # NOTE(review): ``extra_weight_attrs`` is not read again in this
            # method, so this update does not attach ``quant_method`` to any
            # parameter here — confirm whether it was meant to be applied to
            # the scale parameters.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    @staticmethod
    def _requantize_w13_to_shared_scale(layer: torch.nn.Module) -> None:
        """Fold the separate w1/w3 weight scales into one scale per expert.

        ``w13_weight`` is (num_experts, 2 * intermediate_size, hidden_size)
        with the w1 rows first and the w3 rows second. ``w13_weight_scale``
        is (num_experts, 2). Each shard is dequantized with its own scale and
        requantized in place with the per-expert max of the two. The scale
        parameter is then replaced by that (num_experts,) max.
        """
        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Get the maximum scale across w1 and w3 for each expert
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        intermediate_size = layer.w13_weight.shape[1] // 2
        for expert_id in range(layer.w13_weight.shape[0]):
            for shard_id in range(2):  # w1 and w3
                start = shard_id * intermediate_size
                rows = slice(start, start + intermediate_size)
                # Dequantize using the original scale for this shard
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][rows, :],
                    layer.w13_weight_scale[expert_id][shard_id],
                )
                # Requantize using the combined max scale
                layer.w13_weight[expert_id][rows, :], _ = scaled_fp8_quant(
                    dq_weight, max_w13_scales[expert_id]
                )
        layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        # Handle scale parameters
        if getattr(layer, "w13_weight_scale", None) is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            if layer.w13_weight_scale.dim() == 2:
                self._requantize_w13_to_shared_scale(layer)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if getattr(layer, "w2_weight_scale", None) is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )

        # Input scales must be equal for each expert in fp8 MoE layers.
        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = getattr(layer, scale_name, None)
            if input_scale is not None:
                setattr(
                    layer,
                    scale_name,
                    Parameter(input_scale.max(), requires_grad=False),
                )

        if self.flashinfer_moe_backend is not None:
            # FlashInfer kernels take the fused weight as w3/w1 (see
            # swap_w13_to_w31).
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the FP8 W8A8 quant config; ``None`` for the TRT-LLM backend."""
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and run the selected FP8 MoE kernel.

        Dispatch order: the FlashInfer TRT-LLM backend (which routes
        internally) first. Otherwise experts are selected here and run via a
        modular ``fused_experts`` kernel if one is installed, else FlashInfer
        CUTLASS, else the default Triton ``fused_experts``.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert self.fused_experts is None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize, (
                "renormalize is not supported by the FlashInfer TRT-LLM backend"
            )
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # cutlass or fused_experts.
        #
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize, (
                "renormalize is not supported by the FlashInfer CUTLASS backend"
            )
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
670
671


672
673
674
675
676
677
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4.

    Holds the settings parsed from a ModelOpt NVFP4 checkpoint's quantization
    config and maps each layer to the quantization method that should handle
    it (see ``get_quant_method``).
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """
        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint stores
                NVFP4-quantized weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name from the
                checkpoint config, or ``None`` for no KV-cache quantization.
            exclude_modules: Module names / ``*``-wildcard patterns that must
                stay unquantized.
            group_size: Number of weight elements that share one block scale.
        """
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set unconditionally: is_layer_excluded(), apply_vllm_mapper() and
        # the quant methods read these attributes, so leaving them unset for a
        # non-NVFP4 checkpoint produced an AttributeError instead of the
        # intended behaviour / error message.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``exclude_modules`` from HF names to vLLM module names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_kv_cache_quant_algo(source: dict[str, Any]) -> str | None:
        """Read the optional ``kv_cache_quant_algo`` entry from ``source``.

        Raises:
            ValueError: if the entry is present but not a string.
        """
        kv_cache_quant_algo_raw = source.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None:
            # No KV cache quantization by default
            return None
        if isinstance(kv_cache_quant_algo_raw, str):
            return kv_cache_quant_algo_raw
        raise ValueError(
            f"kv_cache_quant_algo must be a string, got "
            f"{type(kv_cache_quant_algo_raw)}"
        )

    @staticmethod
    def _parse_group_size(source: dict[str, Any]) -> int:
        """Read ``group_size`` from ``source``, defaulting to 16.

        Non-int values are coerced with ``int()``.

        Raises:
            ValueError: if the value cannot be converted to an integer.
        """
        group_size_raw = source.get("group_size")
        if group_size_raw is None:
            return 16  # Default value
        if isinstance(group_size_raw, int):
            return group_size_raw
        try:
            return int(group_size_raw)
        except (ValueError, TypeError):
            raise ValueError(
                f"group_size must be an integer, got {type(group_size_raw)}"
            ) from None

    @staticmethod
    def _parse_exclude_modules(source: dict[str, Any], key: str) -> list[str]:
        """Read the exclude list stored under ``key`` (default ``[]``).

        Raises:
            ValueError: if the value is not a list.
        """
        exclude_modules = source.get(key, [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )
        return exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from a checkpoint quantization-config dict.

        Two layouts are accepted:

        * traditional ModelOpt format:
          ``{"quantization": {"quant_algo": "...", "exclude_modules": [...]}}``
        * compressed-tensors style format:
          ``{"quant_algo": "...", "quant_method": "modelopt", "ignore": [...]}``

        Raises:
            ValueError: on malformed fields, an unsupported ``quant_algo``, or
                (traditional format with NVFP4) missing required fields.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            kv_cache_quant_algo = cls._parse_kv_cache_quant_algo(quant_config)
            group_size = cls._parse_group_size(quant_config)
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = cls._parse_exclude_modules(
                quant_config, "exclude_modules"
            )
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_algo = cls._parse_kv_cache_quant_algo(config)
            group_size = cls._parse_group_size(config)
            # "ignore" is the key in config.json
            exclude_modules = cls._parse_exclude_modules(config, "ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check wildcard pattern matching for patterns not caught by exact match
        import regex as re

        for pattern in self.exclude_modules:
            # Patterns with neither "*" nor "." are fully handled by the exact
            # match above.
            if "*" not in pattern and "." not in pattern:
                continue
            # Treat the pattern as a glob in which only "*" is special: escape
            # every regex metacharacter, then turn the escaped "*" into ".*".
            regex_str = re.escape(pattern).replace(r"\*", ".*")
            if re.fullmatch(regex_str, prefix):
                return True

        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quant method for ``layer``, or ``None`` if unhandled.

        Excluded linear layers and vision-tower/vision-model linear layers
        stay unquantized; excluded MoE layers return ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    Returned for attention layers by the ModelOpt configs; all loading logic
    is inherited unchanged from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
        super().__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, one scalar per logical output partition
        (reduced to a single max value after loading),
    weight: NVFP4, two values packed per torch.uint8 byte,
        Shape: [X, Y/2] for X output and Y input features,
    weight_scale: FP8-E4M3, Shape: [X, Y/group_size], aka per block scale,
    weight_scale_2: torch.float32, one scalar per logical output partition
        (reduced to a single max value after loading),
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        # Backend priority: explicit TRT-LLM opt-in, then FlashInfer CUTLASS,
        # then vLLM's own CUTLASS kernel, then the Marlin fallback.
        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
            self.backend = "flashinfer-trtllm"
        elif has_flashinfer():
            self.backend = "flashinfer-cutlass"
        elif cutlass_fp4_supported():
            self.backend = "cutlass"
        elif is_fp4_marlin_supported():
            self.backend = "marlin"
        else:
            raise ValueError(
                "Current platform does not support NVFP4"
                " quantization. Please use Blackwell and"
                " above."
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 weight, block scales and global scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or if ``input_size_per_partition``
                is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Only serialized NVFP4 checkpoints get past the check above, so the
        # per-block weight scales are always stored as FP8-E4M3.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight: NVFP4 values, represented as uint8 bytes.
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale: one scale per `group_size` input elements.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse global scales and convert weights to the backend layout."""
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the selected NVFP4 backend.

        The output keeps all leading dimensions of ``x``; only the last
        dimension becomes the layer's output size.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        # Preserve every leading dim of x (not just dim 0) so inputs with more
        # than two dims can be viewed back correctly after the flat GEMM.
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend == "flashinfer-trtllm":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
        elif self.backend == "flashinfer-cutlass":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
        else:
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
1121
1122


1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Estimate the per-CTA token tile size for the MoE kernel.

    Assumes the routed tokens (``num_tokens * top_k``) are spread evenly over
    ``num_experts``, rounds that per-expert count with ``next_power_of_2``,
    and clamps the result to the 8..64 range supported by the kernel.
    """
    routed_tokens = num_tokens * top_k
    per_expert_pow2 = next_power_of_2(routed_tokens // num_experts)
    # The kernel only supports 8-64 tokens per CTA tile.
    return max(8, min(per_expert_pow2, 64))


1133
1134
1135
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1136
    Args:
1137
1138
1139
        quant_config: NVFP4 Quant Config
    """

1140
1141
1142
1143
1144
1145
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Record the config/layer and probe which NVFP4 MoE kernels are usable.

        Args:
            quant_config: NVFP4 quantization config for this model.
            moe: MoE layer configuration, forwarded to the base class.
            layer: The MoE module this method is attached to.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support,
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin

        # Memoized permutation indices used when shuffling weights for the
        # FlashInfer TRT-LLM kernel.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

        self.flashinfer_moe_backend = None
        if not self.allow_flashinfer:
            return

        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )
1165

1166
    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Pick the prepare/finalize stage for this NVFP4 MoE layer.

        Marlin and the FlashInfer TRT-LLM backend use none (``None``); the
        FlashInfer CUTLASS backend builds a FlashInfer-specific one; all other
        configurations defer to the base class.
        """
        backend = self.flashinfer_moe_backend if self.allow_flashinfer else None

        if self.use_marlin or backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize()

        # For now, fp4 moe only works with the flashinfer dispatcher.
        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
1184

1185
1186
1187
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Choose the NVFP4 expert GEMM implementation.

        ``prepare_finalize`` and ``layer`` are accepted but not consulted
        here; the choice is delegated to ``select_nvfp4_gemm_impl`` using the
        MoE config, the MoE quant config and whether FlashInfer is allowed.
        """
        quant_config = self.moe_quant_config
        assert quant_config is not None
        impl = select_nvfp4_gemm_impl(
            self.moe, quant_config, allow_flashinfer=self.allow_flashinfer
        )
        logger.debug_once("Using %s", impl.__class__.__name__)
        return impl
1198

1199
1200
1201
1202
1203
1204
    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales are named ``weight_scale_2``.

        Always ``True`` for the FP4 variants.
        """
        uses_scale_2 = True
        return uses_scale_2

1205
1206
1207
1208
1209
1210
1211
1212
1213
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 expert weights and their scales on ``layer``.

        Registers:
          * packed FP4 ``w13_weight`` / ``w2_weight`` (uint8, two FP4 values
            per byte along the last dimension),
          * FP8-E4M3 per-block scales ``w13_weight_scale`` /
            ``w2_weight_scale`` (one scale per ``group_size`` elements),
          * float32 per-tensor ``*_weight_scale_2`` and ``*_input_scale``
            parameters (two entries per expert for w13, one for w2).

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block FP8 scales: one scale per `group_size` input elements
        # (scales are not packed, unlike the FP4 weights above).
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): extra_weight_attrs is updated here and below but is
        # not read again within this method; confirm whether it is meant to
        # be applied to the parameters.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        # Per-tensor global weight scales (two per expert for w13).
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Per-tensor input (activation) scales.
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

1316
    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        """Prepare quantized weights for kernel (done offline with weights).

        Reorders and shuffles the packed FP4 expert weights and their FP8
        block scales into the layout used by FlashInfer's TRT-LLM FP4 MoE
        kernel. Permutation indices are memoized in
        ``self._cache_permute_indices``.

        Args:
            gemm1_weights: Packed FP4 w13 weights, viewed here as
                ``(num_experts, 2 * intermediate_size, hidden_size // 2)``.
            gemm2_weights: Packed FP4 w2 weights, viewed here as
                ``(num_experts, hidden_size, intermediate_size // 2)``.
            gemm1_scales_linear_fp4_bytes: FP8 block scales for w13, viewed as
                ``(num_experts, 2 * intermediate_size, hidden_size // 16)``.
            gemm2_scales_linear_fp4_bytes: FP8 block scales for w2, viewed as
                ``(num_experts, hidden_size, intermediate_size // 16)``.
            hidden_size: Model hidden size.
            intermediate_size: MoE intermediate size.
            num_experts: Number of experts in the tensors above.

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)`` in
            the shuffled layout: weights as uint8, scales as FP8-E4M3 with the
            shapes listed above.
        """
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w2_permute_indices,
            _maybe_get_cached_w3_w1_permute_indices,
        )

        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        # Number of FP4 elements covered by one FP8 block-scale factor.
        sf_vec_size = 16

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // sf_vec_size
        )  # fp8 scaling factors

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_fp4_shuffled.append(
                gemm1_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm1_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

            permute_indices = _maybe_get_cached_w2_permute_indices(
                self._cache_permute_indices,
                gemm2_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_fp4_shuffled.append(
                gemm2_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
                self._cache_permute_indices,
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm2_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_vec_size)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // sf_vec_size)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )
1439

1440
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded NVFP4 MoE weights/scales into the layout expected by
        the selected kernel backend.

        Steps common to every backend:
          * when FlashInfer is allowed, reorder the fused w1/w3 projection to
            w3/w1 order (weights and block scales, along dim -2);
          * keep only column 0 of ``w13_weight_scale_2``, warning when it
            differs from column 1;
          * derive ``g1_alphas`` / ``g2_alphas`` (input scale times weight
            global scale) and the inverted activation scales
            ``w13_input_scale_quant`` / ``w2_input_scale_quant``.

        Backend-specific steps:
          * FlashInfer TRT-LLM: shuffle weights and block scales with
            ``prepare_static_weights_for_trtllm_fp4_moe``, register them as
            ``gemm{1,2}_{weights,scales}_fp4_shuffled``, add ``g1_scale_c``
            and delete the original weight/scale parameters.
          * Marlin: call ``prepare_moe_fp4_layer_for_marlin`` and delete the
            alpha / inverted-scale parameters.
          * Otherwise (FlashInfer CUTLASS or plain CUTLASS): validate the
            FP8-E4M3 block scales and swizzle them with ``swizzle_blockscale``.

        Args:
            layer: The MoE layer whose parameters are replaced in place.
        """
        # ---- GEMM 1 (fused w1/w3) weights and block scales ----
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            # FlashInfer kernels consume the fused projection in w3/w1 order.
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Only one global scale (column 0) is kept for w1 and w3; warn when
        # the checkpoint carries different values for the two halves.
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Input scales and alphas. The w1/w3 input scales are reduced to
        # their maximum along dim 1.
        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # ---- GEMM 2 ----
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
        )

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM: g1_alphas / w2_input_scale.
            # apply() passes it as the kernel's output1_scale_scalar.
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer): block scales
            # must be FP8-E4M3 and are swizzled before use.

            def _checked_swizzle(scale: torch.Tensor) -> torch.Tensor:
                # Validate one block-scale tensor (last dim = dim 2), then
                # swizzle it.
                assert scale.shape[2] % 16 == 0, (
                    "Expected weight_scale.dim(2) to be divisible by 16"
                )
                assert scale.dtype == torch.float8_e4m3fn, (
                    "Weight Blockscale must be represented as FP8-E4M3"
                )
                return swizzle_blockscale(scale)

            layer.w13_weight_scale = Parameter(
                _checked_swizzle(layer.w13_weight_scale), requires_grad=False
            )
            layer.w2_weight_scale = Parameter(
                _checked_swizzle(layer.w2_weight_scale), requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
1566

1567
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the NVFP4 quant config used by the modular fused-MoE kernels.

        Returns ``None`` for the Marlin and FlashInfer TRT-LLM paths, for two
        reasons. First, ``process_weights_after_loading`` deletes attributes
        this config would read on those paths: ``g1_alphas``, ``g2_alphas``
        and the ``*_input_scale_quant`` tensors for Marlin, and
        ``w13_weight_scale`` / ``w2_weight_scale`` for TRT-LLM. Second,
        ``apply`` dispatches both paths without a quant config.

        The TRT-LLM check is gated on ``self.allow_flashinfer``, like the
        matching checks in ``process_weights_after_loading`` and ``apply``.
        This keeps all three in agreement about which backend is active.

        Args:
            layer: MoE layer whose processed NVFP4 scales feed the config.

        Returns:
            An NVFP4 ``FusedMoEQuantConfig``, or ``None`` as described above.
        """
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

1585
1586
1587
1588
1589
1590
1591
1592
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass on the configured backend.

        Dispatch order (first match wins):

        1. FlashInfer TRT-LLM (``allow_flashinfer`` and backend
           ``TENSORRT_LLM``): quantize ``x`` to FP4 with
           ``flashinfer.fp4_quantize``, then call
           ``flashinfer.fused_moe.trtllm_fp4_block_scale_moe``. That kernel
           does the routing itself, so ``FusedMoE.select_experts`` is not
           called.
        2. Otherwise, experts are selected with ``FusedMoE.select_experts``.
           The pass then runs on the first applicable of: Marlin
           (``self.use_marlin``), the modular kernel ``self.fused_experts``
           (which must be FlashInfer CUTLASS), ``flashinfer_cutlass_moe_fp4``
           (FlashInfer CUTLASS backend), or ``cutlass_moe_fp4``.

        Args:
            layer: MoE layer prepared by ``process_weights_after_loading``.
            x: Input hidden states.
            router_logits: Router logits for ``x``.
            top_k, renormalize, use_grouped_topk, topk_group,
            num_expert_group, custom_routing_function, scoring_func,
            routed_scaling_factor, e_score_correction_bias: Routing options
                forwarded to ``FusedMoE.select_experts``. The TRT-LLM path
                uses only a subset; see the note in that branch.
            global_num_experts: Total expert count forwarded to the kernels.
            expert_map: Forwarded to the non-TRT-LLM kernels.
            apply_router_weight_on_input: Forwarded to the non-TRT-LLM
                kernels.
            activation: Only ``"silu"`` is accepted.
            enable_eplb: Must be False; EPLB is not supported.
            expert_load_view, logical_to_physical_map,
            logical_replica_count: EPLB inputs; unused here.

        Returns:
            The MoE output of the selected kernel. On the TRT-LLM path this
            is element 0 of the FlashInfer result.

        Raises:
            NotImplementedError: If ``enable_eplb`` is True.
            AssertionError: If ``activation`` is not ``"silu"`` or a backend
                precondition assertion fails.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Imported lazily so they are only needed on the TRT-LLM path.
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            # The TRT-LLM path is never overridden by a modular kernel.
            assert self.fused_experts is None

            # Quantize activations to FP4 with the inverted GEMM1 input scale.
            # The scale factors are produced in the non-swizzled layout.
            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )
            # NOTE(review): every routing function other than
            # Llama4MoE.custom_routing_function is mapped to DeepSeekV3
            # routing. renormalize, use_grouped_topk, scoring_func,
            # routed_scaling_factor, expert_map and
            # apply_router_weight_on_input are not forwarded to the kernel
            # (routed_scaling_factor=None is passed). Confirm this is
            # intended for every model that can reach this path.
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4
            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                # Llama4 routing receives the logits unchanged; other routing
                # methods get float32 (presumably a kernel dtype requirement
                # per routing method; confirm).
                routing_logits=router_logits
                if use_llama4_routing
                else router_logits.to(torch.float32),
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                # NOTE(review): global_num_experts defaults to -1 and is
                # passed through unchanged; callers must supply the real
                # expert count on this path.
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                # Presumably the global id of this rank's first local expert.
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=_get_tile_tokens_dim(
                    x.shape[0], top_k, layer.local_num_experts
                ),
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        # All remaining backends consume precomputed top-k weights and ids.
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
        # trtllm.
        #
        if self.use_marlin:
            assert self.fused_experts is None
            # Marlin reads the per-layer global scales (w*_weight_scale_2)
            # directly rather than a FusedMoEQuantConfig.
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        elif self.fused_experts is not None:
            # A modular kernel is only expected with the FlashInfer CUTLASS
            # backend.
            assert (
                self.allow_flashinfer
                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            )

            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight
            ), "Flashinfer CUTLASS Fused MoE not applicable!"

            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4,
            )

            # Scales and alphas come from get_fused_moe_quant_config().
            assert self.moe_quant_config is not None

            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                # w2_weight's last dim presumably holds packed FP4 values
                # (two per element), hence the * 2. TODO confirm.
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )