modelopt.py 68.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.fused_moe.config import (
16
17
18
19
20
    FusedMoEConfig,
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
22
23
    is_valid_flashinfer_cutlass_fused_moe,
)
24
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
25
from vllm.model_executor.layers.fused_moe.layer import (
26
27
28
29
30
31
32
33
34
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
35
from vllm.model_executor.layers.quantization import QuantizationMethods
36
from vllm.model_executor.layers.quantization.base_config import (
37
38
39
    QuantizationConfig,
    QuantizeMethodBase,
)
40
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
41
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
42
43
44
45
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
46
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
47
48
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
49
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
50
51
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
52
    is_flashinfer_supporting_global_sf,
53
54
55
56
57
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
58
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
59
60
61
62
63
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
64
from vllm.model_executor.layers.quantization.utils.quant_utils import (
65
66
67
68
69
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
70
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
71
72
73
74
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
75
from vllm.scalar_type import scalar_types
76
77
78
79
80
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
81

82
83
84
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

85
86
# Module-level logger shared by the ModelOpt quantization configs and methods
# defined in this file.
logger = init_logger(__name__)

87
88
# ``quant_algo`` values accepted by ModelOptFp8Config.from_config; any other
# value makes from_config raise ValueError.
QUANT_ALGOS = ["FP8", "NVFP4"]
# KV-cache quantization algorithms for ModelOpt checkpoints.
# NOTE(review): not referenced in the visible part of this file — presumably
# consulted by code further down; confirm.
KV_CACHE_QUANT_ALGOS = ["FP8"]
89
90
91
92
93
94
95
96


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint stores weights
            already quantized to FP8 (with serialized scales).
        kv_cache_quant_method: The ``kv_cache_quant_algo`` value read from
            the checkpoint config, if any.
        exclude_modules: Module names/patterns that must stay unquantized.
            ``None`` is normalized to an empty list.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: str | None = None,
        exclude_modules: list[str] | None = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename the exclude list from HF module names to vLLM names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` contains ``"FP8"``, in
        either the nested ModelOpt layout or the flat compressed-tensors
        style layout. Returns ``None`` otherwise, including when the relevant
        fields are missing, ``null`` or not strings.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. An explicit null or a
        # non-string value means "not ModelOpt" rather than a crash.
        quant_method = hf_quant_cfg.get("quant_method")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific layout: {"quantization": {"quant_algo": ...}}
            quant_config = hf_quant_cfg["quantization"]
            quant_algo = (
                quant_config.get("quant_algo")
                if isinstance(quant_config, dict)
                else None
            )
        else:
            # Compressed-tensors style config with a top-level quant_algo.
            quant_algo = hf_quant_cfg.get("quant_algo")

        # Guard both layouts the same way: `"FP8" in None` would raise.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from either supported checkpoint layout.

        Raises:
            ValueError: if ``quantization`` is not a dict, if ``quant_algo``
                is missing (nested layout), or if ``quant_algo`` is not one of
                ``QUANT_ALGOS``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        # Substring test: true for "FP8", false for "NVFP4".
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Linear layers that are excluded, or that belong to a vision tower or
        vision model, stay unquantized. Other linear layers, attention layers
        and fused-MoE layers get their ModelOpt FP8 method. Any other layer
        type returns ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation (input) scale. Dynamic scales may be supported later.

    Limitations:
        1. Only per-tensor quantization is supported, because of
           ``torch._scaled_mm`` support.
        2. Only the ``float8_e4m3fn`` datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations are quantized with a static, per-tensor scale.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, its scale parameters.

        The weight is FP8 (``float8_e4m3fn``) when the checkpoint is FP8
        serialized, otherwise ``params_dtype``. For serialized checkpoints,
        one ``weight_scale`` and one ``input_scale`` entry is allocated per
        output partition, pre-filled with the float32 minimum.
        """
        del input_size, output_size
        total_out = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        serialized = self.quant_config.is_checkpoint_fp8_serialized

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        storage_dtype = torch.float8_e4m3fn if serialized else params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_out, input_size_per_partition, dtype=storage_dtype
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not serialized:
            return

        # One per-tensor scale per logical shard, for both the weight and the
        # activation, registered in that order.
        shard_count = len(output_partition_sizes)
        fill_value = torch.finfo(torch.float32).min
        for param_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.empty(shard_count, dtype=torch.float32),
                weight_loader=loader,
            )
            scale_param[:] = fill_value
            layer.register_parameter(param_name, scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales into one per-tensor scale.

        When the shard weight scales differ, the weight is requantized to the
        max scale via ``requantize_with_max_scale``. The weight is stored
        transposed, and the input scale becomes the max of its shard entries.
        """
        scales = layer.weight_scale
        if bool((scales == scales[0]).all()):
            fused_scale, fused_weight = scales.max(), layer.weight
        else:
            fused_scale, fused_weight = requantize_with_max_scale(
                layer.weight, scales, layer.logical_widths
            )

        layer.weight = Parameter(fused_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(fused_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM using the layer's static weight/input scales."""
        op = self.fp8_linear
        return op.apply(
            input=x,
            weight=layer.weight,
            bias=bias,
            input_scale=layer.input_scale,
            weight_scale=layer.weight_scale,
        )
333
334


335
336
337
338
339
340
341
342
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The FusedMoE layer; its ``moe_config`` is handed to the base
            class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # FlashInfer is opt-in via VLLM_USE_FLASHINFER_MOE_FP8 and only used
        # when the FlashInfer MoE kernels are available.
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the w13/w2 expert weights and, for FP8 checkpoints, the
        per-tensor weight and input scales (initialized to 1.0)."""
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    @staticmethod
    def _requantize_w13_to_max_scale(layer: torch.nn.Module) -> None:
        """Collapse the (num_experts, 2) w1/w3 scales into one scale per expert.

        The FP8 MoE kernel needs a single w13 weight scale per expert. This
        takes the max of the w1 and w3 scales, then dequantizes and
        requantizes each shard whose scale differs from that max.
        ``w13_weight`` is laid out as (num_experts, 2 * intermediate_size,
        hidden_size): the first ``intermediate_size`` rows are w1 and the next
        are w3.
        """
        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        w13_scales = layer.w13_weight_scale
        max_w13_scales = w13_scales.max(dim=1).values

        # Decide on the host, with a single device->host transfer, which
        # shards actually need work. A shard already at its expert's max
        # scale would be dequantized and requantized with the same scale.
        # That is at best a no-op and costs two kernel launches. The linear
        # method likewise only requantizes when shard scales differ.
        needs_requant = (w13_scales != max_w13_scales.unsqueeze(1)).tolist()
        intermediate_size = layer.w13_weight.shape[1] // 2

        for expert_id, shard_flags in enumerate(needs_requant):
            for shard_id, requant in enumerate(shard_flags):
                if not requant:
                    continue
                rows = slice(
                    shard_id * intermediate_size,
                    (shard_id + 1) * intermediate_size,
                )
                # Dequantize using the original scale for this shard
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][rows, :],
                    w13_scales[expert_id][shard_id],
                )
                # Requantize in place using the combined max scale
                layer.w13_weight[expert_id][rows, :], _ = scaled_fp8_quant(
                    dq_weight, max_w13_scales[expert_id]
                )

        # Update the scale parameter to be per-expert
        layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Collapses w13 weight scales to one per expert (requantizing only the
        shards that need it). Reduces each input scale to its max over
        experts. For FlashInfer backends, it also swaps w13 to w31 and
        registers scaling factors. The TensorRT-LLM backend additionally
        rotates the weights.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            if layer.w13_weight_scale.dim() == 2:
                self._requantize_w13_to_max_scale(layer)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        # The TensorRT-LLM path reads scales from the layer directly.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the FP8 experts.

        Dispatch order: the FlashInfer TensorRT-LLM path (which does its own
        routing), then ``self.fused_experts`` if set, then FlashInfer
        CUTLASS, and finally the Triton ``fused_experts`` kernel.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert self.fused_experts is None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # cutlass or fused_experts.
        #
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
671
672


673
674
675
676
677
678
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4 (NVFP4).

    Accepts two on-disk layouts:

    - legacy ``hf_quant_config.json``: ``{"quantization": {"quant_algo": ...}}``
    - compressed-tensors style (``config.json``):
      ``{"quant_algo": ..., "quant_method": "modelopt"}``
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set these unconditionally: is_layer_excluded / get_quant_method /
        # apply_vllm_mapper read them regardless of the serialization flag, and
        # previously raised AttributeError for non-NVFP4 checkpoints.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename HF module names in ``exclude_modules`` to vLLM names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``hf_quant_cfg`` declares
        ``quant_method == "modelopt"`` (case-insensitive) together with an FP4
        ``quant_algo`` (nested under ``"quantization"`` or at top level);
        otherwise ``None``. Missing or non-string values yield ``None`` rather
        than raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'; only "modelopt" qualifies.
        quant_method = hf_quant_cfg.get("quant_method")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Compressed-tensors style config with a top-level quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_kv_cache_quant_algo(cfg: dict[str, Any]) -> str | None:
        """Return ``cfg["kv_cache_quant_algo"]`` (``None`` if absent).

        Raises:
            ValueError: If the value is present but not a string.
        """
        raw = cfg.get("kv_cache_quant_algo")
        if raw is None or isinstance(raw, str):
            return raw
        raise ValueError(f"kv_cache_quant_algo must be a string, got {type(raw)}")

    @staticmethod
    def _parse_group_size(cfg: dict[str, Any]) -> int:
        """Return ``cfg["group_size"]`` as an int (16 if absent).

        Raises:
            ValueError: If the value cannot be converted with ``int()``.
        """
        raw = cfg.get("group_size")
        if raw is None:
            return 16
        if isinstance(raw, int):
            return raw
        try:
            return int(raw)
        except (ValueError, TypeError):
            raise ValueError(
                f"group_size must be an integer, got {type(raw)}"
            ) from None

    @staticmethod
    def _parse_exclude_modules(cfg: dict[str, Any], key: str) -> list[str]:
        """Return the excluded-module list stored under ``key`` ([] if absent).

        Raises:
            ValueError: If the value is not a list.
        """
        exclude_modules = cfg.get(key, [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )
        return exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from either supported on-disk layout.

        Raises:
            ValueError: On a malformed ``quantization`` section, a missing or
                unsupported ``quant_algo``, badly typed fields, or (legacy
                NVFP4 layout) missing required fields.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_key = "exclude_modules"
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_config = config
            quant_method = config.get("quant_algo", "")
            # "ignore" is the key in config.json
            exclude_key = "ignore"

        kv_cache_quant_algo = cls._parse_kv_cache_quant_algo(quant_config)
        group_size = cls._parse_group_size(quant_config)
        exclude_modules = cls._parse_exclude_modules(quant_config, exclude_key)

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4 in the legacy layout, these fields must be spelled out
        # explicitly rather than relying on the defaults above.
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and pattern matching.
        Patterns containing ``*`` or ``.`` are additionally tried as a full
        regex match in which ``.`` is literal and ``*`` matches any run of
        characters; other characters are passed to the regex engine as-is.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        import regex as re

        # Fall back to wildcard matching for patterns exact matching missed.
        for pattern in self.exclude_modules:
            if "*" in pattern or "." in pattern:
                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
                if re.fullmatch(regex_str, prefix):
                    return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``.

        Linear layers get NVFP4 unless excluded or part of a vision encoder,
        attention layers get the FP8 KV-cache method, and non-excluded FusedMoE
        layers get the NVFP4 MoE method. Anything else returns ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            # Excluded layers and vision-model layers stay unquantized.
            if skip_layer or "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        if isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    Used by both the ModelOpt FP8 and NVFP4 configs; all loading behavior is
    inherited from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
        super().__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following per-layer tensors
    (N = output_size_per_partition, K = input_size_per_partition):

    - ``weight``: packed FP4 values, two per uint8, shape ``[N, K // 2]``.
    - ``weight_scale``: FP8-E4M3 per-block scales, shape
      ``[N, K // group_size]``.
    - ``weight_scale_2``: float32 global weight scale, one entry per logical
      output partition; reduced to a scalar (max) after loading.
    - ``input_scale``: float32 global activation scale, one entry per logical
      output partition; reduced to a scalar (max) after loading.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        requested_backend = envs.VLLM_NVFP4_GEMM_BACKEND
        self.backend = "none"
        if requested_backend is None:
            # Auto-select: FlashInfer, then native CUTLASS, then Marlin.
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif requested_backend.startswith("flashinfer-"):
            self.backend = requested_backend
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested_backend == "cutlass" and cutlass_fp4_supported():
            # apply()/process_weights_after_loading() handle these non-FlashInfer
            # backends, so honor an explicit request instead of erroring out.
            self.backend = "cutlass"
        elif requested_backend == "marlin" and is_fp4_marlin_supported():
            self.backend = "marlin"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 weight and scale parameters on ``layer``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized (dynamic
                quantization is unsupported) or the per-partition input size
                is not a multiple of 16.
        """
        del input_size, output_size, params_dtype
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Weight: 2 fp4 items are packed per uint8 along the input dimension.
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input (activation) global scale, one per logical output partition.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global weight scale, one per logical output partition.
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per-block weight scale. Dynamic quantization was rejected above, so
        # the checkpoint is NVFP4-serialized and block scales are FP8-E4M3.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse global scales and re-layout block scales for the backend."""
        # Global scales: reduce the per-partition entries to one scalar each.
        input_scale = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        # Combined scale passed to the FP4 GEMM as `alpha` in apply().
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the selected NVFP4 backend.

        The result keeps every leading dimension of ``x``; the last dimension
        becomes the layer's output size.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        # Keep all leading dims of x (previously only x.shape[0], which broke
        # for inputs with more than two dimensions). Assumes the FP4 kernels
        # return a flattened [prod(leading dims), N] result — TODO confirm.
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1131
    Args:
1132
1133
1134
        quant_config: NVFP4 Quant Config
    """

1135
1136
1137
1138
1139
1140
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Detect the available NVFP4 MoE kernels and pick a FlashInfer backend."""
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support,
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        self.flashinfer_moe_backend = None
        # Permute-index cache (keyed by tensor shape) handed to FlashInfer's
        # permute-index helpers in prepare_static_weights_for_trtllm_fp4_moe.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for ModelOptNvFp4FusedMoE."
            )

    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Select the prepare/finalize stage for this NVFP4 MoE layer.

        Returns ``None`` for Marlin and the FlashInfer TRT-LLM backend, the
        FlashInfer FP4 CUTLASS prepare/finalize for the FlashInfer CUTLASS
        backend, and otherwise defers to the base class.
        """
        uses_flashinfer_trtllm = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if self.use_marlin or uses_flashinfer_trtllm:
            return None

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize

        return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the NVFP4 experts implementation for a modular MoE kernel.

        ``prepare_finalize`` and ``layer`` are part of the base-class interface
        and are not used here; selection depends only on the MoE config, the
        quant config and whether FlashInfer is allowed.
        """
        assert self.moe_quant_config is not None
        experts = select_nvfp4_gemm_impl(
            self.moe,
            self.moe_quant_config,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 expert weights, block scales and global scales.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized (dynamic
                quantization is not supported).
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        group_size = self.quant_config.group_size

        # NOTE(review): the tensors below are laid out [num_experts, out, in]
        # but are tagged input_dim=1 / output_dim=2. Presumably FusedMoE's
        # weight_loader shards them itself and ignores these attributes —
        # confirm before relying on them.

        # GEMM 1 (fused w1/w3): 2 fp4 items are packed per uint8 along the
        # input (hidden) dimension.
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2: 2 fp4 items are packed per uint8 along the input
        # (intermediate) dimension.
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block FP8 scales: one scale per `group_size` input elements.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-expert global weight scales; w13 keeps one per fused shard.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # Some FlashInfer backends take activation scales for all experts
        # (global_num_experts) instead of only the local ones.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        global_scale_num_experts = global_num_experts if use_global_sf else num_experts

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        """Prepare quantized weights for kernel (done offline with weights).

        For every expert, reorders and shuffles the packed FP4 weights and FP8
        block scales of both GEMMs using FlashInfer's permute-index helpers,
        interleaves the scales, and stacks the results.

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)``.
            Weights are uint8 stacked per expert; scales are FP8-E4M3 reshaped
            to ``[E, 2*I, H // 16]`` and ``[E, H, I // 16]``.
        """
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w3_w1_permute_indices,
            get_w2_permute_indices_with_cache,
        )

        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        sf_vec_size = 16  # fp4 elements covered by one fp8 block scale

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // sf_vec_size
        )  # fp8 scaling factors

        def _gather_rows(t: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
            # Reorder rows of `t` viewed as raw bytes and materialize the result.
            return t.view(torch.uint8)[indices.to(t.device)].contiguous()

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            w13 = gemm1_weights_fp4[i]
            w13_sf = gemm1_scales_linear_fp4[i]
            w2 = gemm2_weights_fp4[i]
            w2_sf = gemm2_scales_linear_fp4[i]

            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            w13_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                w13.view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_fp4_shuffled.append(_gather_rows(w13, w13_indices))

            w13_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                w13_sf.view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(_gather_rows(w13_sf, w13_sf_indices))
            )

            w2_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                w2.view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_fp4_shuffled.append(_gather_rows(w2, w2_indices))

            w2_sf_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                w2_sf.view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(_gather_rows(w2_sf, w2_sf_indices))
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_vec_size)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // sf_vec_size)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded NVFP4 MoE checkpoint tensors into kernel-ready form.

        Steps:
          * optionally reorder the fused w1/w3 weights and block scales for
            FlashInfer,
          * keep a single ``w13_weight_scale_2`` per expert (the w1 column),
            warning if the w3 column differs,
          * derive ``g1_alphas`` / ``g2_alphas`` and the inverted activation
            quantization scales ``w13_input_scale_quant`` /
            ``w2_input_scale_quant``,
          * then lay out weights for the active backend: FlashInfer TRT-LLM
            (shuffled weights/scales, originals deleted), Marlin (alphas and
            activation quant scales deleted), or CUTLASS (swizzled block
            scales).
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            # Swap the fused w1/w3 halves (and their block scales) along the
            # row dimension into the w3/w1 order used on the FlashInfer path.
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2: only one global weight
        # scale per expert is kept for the fused w1/w3 GEMM (column 0), so a
        # mismatching w3 scale cannot be represented exactly.
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            # Per-expert: take the larger of the w1/w3 input scales.
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts.
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale
        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM: the GEMM-1 alpha
            # multiplied by the inverted GEMM-2 input scale, so the GEMM-1
            # output is produced directly in GEMM-2's quantized input range.
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing; Marlin does not consume the alphas or the
            # activation quantization scales, so drop them.
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 ``FusedMoEQuantConfig`` from the processed layer scales.

        Returns ``None`` for the Marlin and FlashInfer TRT-LLM paths. In
        ``apply`` those paths read their scales directly from ``layer``
        instead of through a quant config.
        """
        # Use the same TRT-LLM condition as process_weights_after_loading()
        # and apply(). The TRT-LLM layout is only prepared when FlashInfer is
        # allowed. Otherwise the weights were set up for CUTLASS, which needs
        # a real config here.
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass on the configured backend.

        Dispatch order:
          1. FlashInfer TRT-LLM: quantizes ``x`` and routes inside the kernel.
          2. Otherwise experts are selected with ``FusedMoE.select_experts``,
             then, in order: Marlin, a pre-built modular kernel
             (``self.fused_experts``, FlashInfer CUTLASS only), FlashInfer
             CUTLASS, and finally vLLM's CUTLASS FP4 kernel.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            assert self.fused_experts is None

            # Quantize activations to FP4 using the (inverted) GEMM-1 input
            # scale; scale factors are requested in linear (unswizzled) layout.
            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            # DeepSeekV3 routing unless the Llama4 routing function is used.
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4
            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)
            # NOTE(review): renormalize, scoring_func, expert_map,
            # apply_router_weight_on_input and routed_scaling_factor are not
            # forwarded to this kernel (routed_scaling_factor is passed as
            # None) — confirm this is intended for every model on this path.
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=router_logits
                if use_llama4_routing
                else router_logits.to(torch.float32),
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
        # trtllm.
        #
        if self.use_marlin:
            assert self.fused_experts is None
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                # NOTE(review): presumably the optional w1/w2 biases — confirm
                # against fused_marlin_moe's signature.
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        elif self.fused_experts is not None:
            assert (
                self.allow_flashinfer
                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            )

            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight
            ), "Flashinfer CUTLASS Fused MoE not applicable!"

            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4,
            )

            assert self.moe_quant_config is not None

            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )