modelopt.py 68.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
5
6
7
8
9

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

10
11
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.fused_moe.config import (
15
16
17
18
19
    FusedMoEConfig,
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
20
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
21
22
    is_valid_flashinfer_cutlass_fused_moe,
)
23
from vllm.model_executor.layers.fused_moe.layer import (
24
25
26
27
28
29
30
31
32
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
33
from vllm.model_executor.layers.quantization import QuantizationMethods
34
from vllm.model_executor.layers.quantization.base_config import (
35
36
37
    QuantizationConfig,
    QuantizeMethodBase,
)
38
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
39
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
40
41
42
43
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
44
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
45
46
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
47
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
48
49
50
51
52
53
54
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
55
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
56
57
58
59
60
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
61
from vllm.model_executor.layers.quantization.utils.quant_utils import (
62
63
64
65
66
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
67
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
68
69
70
71
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
72
from vllm.scalar_type import scalar_types
73
from vllm.utils import next_power_of_2
74
75
76
77
78
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
79

80
81
82
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

83
84
logger = init_logger(__name__)

# Values of `quant_algo` accepted by ModelOptFp8Config.from_config; any other
# value is rejected with a ValueError there.
QUANT_ALGOS = ["FP8", "NVFP4"]
# KV-cache quantization algorithms for ModelOpt checkpoints.
# NOTE(review): not referenced in this portion of the file; presumably checked
# by the KV-cache handling elsewhere in the module — confirm.
KV_CACHE_QUANT_ALGOS = ["FP8"]
87
88
89
90
91
92
93
94


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Built from either the legacy ``hf_quant_config.json`` layout
    (``{"quantization": {"quant_algo": ...}}``) or the compressed-tensors
    style layout (``{"quant_algo": ..., "quant_method": "modelopt"}``), and
    hands out the per-layer quantize method via ``get_quant_method``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        """Store the parsed ModelOpt FP8 settings.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint's weights are
                already stored in FP8.
            kv_cache_quant_method: The checkpoint's ``kv_cache_quant_algo``
                value, if any.
            exclude_modules: Module names/patterns that must stay
                unquantized; ``None`` is normalized to an empty list.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (encoded as an int)."""
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config file names to look for in the checkpoint."""
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``exclude_modules`` from HF names to vLLM names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` equals ``"modelopt"``
        (case-insensitive) and the string ``quant_algo`` (nested under
        ``"quantization"`` or at top level) contains ``"FP8"``; otherwise
        returns ``None``. ``user_quant`` is not consulted.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. `or ""` also covers an
        # explicit null value, and the isinstance check rejects non-string
        # values; both would otherwise crash on `.lower()`.
        quant_method = hf_quant_cfg.get("quant_method") or ""
        if not isinstance(quant_method, str):
            return None

        # Only proceed if the method is explicitly "modelopt"
        if quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Compressed-tensors style config with top-level quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Require a string in both layouts; a null quant_algo in the nested
        # layout would otherwise raise TypeError on the `in` test.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from a parsed quantization config dict.

        Raises:
            ValueError: if ``"quantization"`` is not a dict, if its
                ``quant_algo`` is missing, or if the algorithm is not one of
                ``QUANT_ALGOS``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer`` (named ``prefix``).

        Excluded linear layers and linear layers under a ``vision_tower`` /
        ``vision_model`` prefix stay unquantized. Other linear layers,
        attention layers and fused-MoE layers get the matching ModelOpt FP8
        method. Any other layer gets ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt (Model Optimizer) static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation (input) scale. Dynamic activation scales may be supported in
    the future.

    Limitations:
    1. Only per-tensor quantization is supported (torch._scaled_mm support).
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations use a static, per-tensor quantization scale.
        self.fp8_linear = Fp8LinearOp(
            act_quant_group_shape=GroupShape.PER_TENSOR,
            act_quant_static=True,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, ``weight_scale`` and
        ``input_scale`` (one entry per logical output shard) on ``layer``."""
        # Only the per-partition sizes matter here.
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        rows = sum(output_partition_sizes)
        shard_count = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = rows

        if fp8_serialized:
            storage_dtype = torch.float8_e4m3fn
        else:
            storage_dtype = params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(rows, input_size_per_partition, dtype=storage_dtype),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # Weight scale first, then input scale; both are pre-filled with the
        # float32 minimum (presumably so an unloaded shard can never win the
        # later max() reduction — confirm).
        for scale_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.empty(shard_count, dtype=torch.float32),
                weight_loader=loader,
            )
            scale_param[:] = torch.finfo(torch.float32).min
            layer.register_parameter(scale_name, scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales into single scalars and store the weight
        transposed."""
        shard_scales = layer.weight_scale
        if bool((shard_scales == shard_scales[0]).all()):
            # All shards already share one scale; keep the weight unchanged.
            fused_weight = layer.weight
            fused_scale = shard_scales.max()
        else:
            # Shards disagree: fold them onto one (max) scale.
            fused_scale, fused_weight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )
        layer.weight = Parameter(fused_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(fused_scale, requires_grad=False)
        # Per-shard input scales collapse to their maximum.
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 GEMM on ``x`` with the layer's static scales."""
        op = self.fp8_linear
        return op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
331
332


333
334
335
336
337
338
339
340
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    activation scales. Experts run on one of:

    * FlashInfer, when ``VLLM_USE_FLASHINFER_MOE_FP8`` is set and FlashInfer
      MoE is available (TensorRT-LLM or CUTLASS backend);
    * a modular kernel installed in ``self.fused_experts``;
    * otherwise ``fused_experts`` from ``fused_moe.fused_moe``.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The fused-MoE layer; its ``moe_config`` is passed to the base
            class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        # FlashInfer is used only when requested via env var AND available.
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Return the prepare/finalize object for the selected backend.

        ``None`` for TensorRT-LLM, the FlashInfer FP8 CUTLASS variant for the
        CUTLASS backend, otherwise the base-class choice.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the CUTLASS FP8 experts implementation.

        ``prepare_finalize`` and ``layer`` are accepted for interface
        compatibility but are not used here. Requires ``moe_quant_config``.
        """
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights (and, for FP8 checkpoints, scales) on
        ``layer``.

        Shapes:
            w13_weight: (num_experts, 2 * intermediate_size_per_partition,
                hidden_size)
            w2_weight: (num_experts, hidden_size,
                intermediate_size_per_partition)
            w13_weight_scale: (num_experts, 2) — one per w1/w3 shard
            w2_weight_scale, w13_input_scale, w2_input_scale: (num_experts,)
        All scales start at 1.0.
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            # NOTE(review): extra_weight_attrs is not read after this point in
            # this method, so this update has no visible effect — confirm
            # whether it was meant to be applied to the scale parameters.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Steps:
        1. A 2-D ``w13_weight_scale`` (separate w1/w3 scales) is reduced to a
           per-expert max, and each w1/w3 half is dequantized with its own
           scale and requantized in place with that max.
        2. Input scales are collapsed to a single scalar max across experts.
        3. For FlashInfer backends, w13 is passed through ``swap_w13_to_w31``
           and scaling factors are registered; TensorRT-LLM additionally
           rotates the weights.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale, writing the
                        # result back into the same rows.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> Optional[FusedMoEQuantConfig]:
        """Return the static per-tensor FP8 quant config for the experts.

        Returns ``None`` for the TensorRT-LLM backend, which is driven directly
        from ``apply``.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the FP8 MoE forward pass.

        The TensorRT-LLM backend performs routing itself. All other paths
        select experts with ``FusedMoE.select_experts`` first.
        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted for interface compatibility
        and are unused.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert self.fused_experts is None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # cutlass or fused_experts.
        #
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
669
670


671
672
673
674
675
676
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4.

    Accepts both the legacy ``hf_quant_config.json`` layout
    (``{"quantization": {"quant_algo": ...}}``) and the compressed-tensors
    style layout (``{"quant_method": "modelopt", "quant_algo": ...}``).
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """
        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint already
                stores NVFP4 weights.
            kv_cache_quant_algo: KV-cache quantization algorithm name, or
                None for no KV-cache quantization.
            exclude_modules: Module names / ``*`` wildcard patterns that
                must stay unquantized.
            group_size: Number of weight elements sharing one block scale.
        """
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Set unconditionally: is_layer_excluded() (reached through
        # get_quant_method()) and apply_vllm_mapper() read these regardless
        # of the serialization flag.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename excluded module names from HF naming to vLLM naming."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns ``"modelopt_fp4"`` when ``quant_method`` is ``modelopt`` and
        the algorithm is an FP4 variant, otherwise None. Malformed (missing or
        non-string) fields are treated as "not a match" instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_kv_cache_quant_algo(raw: Any) -> Optional[str]:
        """Validate ``kv_cache_quant_algo``: None (no KV-cache quantization)
        or a string; anything else raises ValueError."""
        if raw is None or isinstance(raw, str):
            return raw
        raise ValueError(f"kv_cache_quant_algo must be a string, got {type(raw)}")

    @staticmethod
    def _parse_group_size(raw: Any) -> int:
        """Validate ``group_size``: defaults to 16 when absent, otherwise must
        be an int or int-convertible; anything else raises ValueError."""
        if raw is None:
            return 16  # Default value
        if isinstance(raw, int):
            return raw
        try:
            return int(raw)
        except (ValueError, TypeError):
            raise ValueError(
                f"group_size must be an integer, got {type(raw)}"
            ) from None

    @staticmethod
    def _parse_exclude_modules(raw: Any) -> list[str]:
        """Validate that the excluded-module list is a list."""
        if not isinstance(raw, list):
            raise ValueError(f"exclude_modules must be a list, got {type(raw)}")
        return raw

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build the config from either supported JSON layout.

        Raises:
            ValueError: if the config is malformed, names an algorithm not in
                ``QUANT_ALGOS``, or (legacy layout, NVFP4) lacks a required
                field.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            settings = quant_config
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_key = "exclude_modules"
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            settings = config
            # "ignore" is the key in config.json
            exclude_key = "ignore"

        kv_cache_quant_algo = cls._parse_kv_cache_quant_algo(
            settings.get("kv_cache_quant_algo")
        )
        group_size = cls._parse_group_size(settings.get("group_size"))
        exclude_modules = cls._parse_exclude_modules(settings.get(exclude_key, []))

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4 in the legacy layout, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in settings
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check regex pattern matching for patterns not caught by exact match
        import regex as re

        for pattern in self.exclude_modules:
            # Only dotted or wildcard patterns are matched here; the rest are
            # left to the exact-match check above.
            if "*" not in pattern and "." not in pattern:
                continue
            # Escape every regex metacharacter, then re-expand the glob "*"
            # so that "*" is the only wildcard.
            regex_str = re.escape(pattern).replace(r"\*", ".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None for no method.

        Excluded linear layers and linear layers under a vision tower/model
        stay unquantized; attention layers get the FP8 KV-cache method; MoE
        layers get NVFP4 unless excluded.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method shared by the ModelOpt FP8 and NVFP4 configs.

    Loads kv-cache scaling factors from FP8 checkpoints; all of the work is
    inherited from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config]):
        # Nothing ModelOpt-specific to set up; defer entirely to the base.
        super().__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, one scalar per fused output partition.
    weight: NVFP4 packed two values per byte (torch.uint8),
        shape [output_size_per_partition, input_size_per_partition // 2].
    weight_scale: FP8-E4M3 per-block scale,
        shape [output_size_per_partition,
               input_size_per_partition // group_size].
    weight_scale_2: torch.float32, one scalar per fused output partition.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        # Backend priority: explicit TRTLLM opt-in, FlashInfer CUTLASS,
        # vLLM CUTLASS, then Marlin as the last resort.
        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
            self.backend = "flashinfer-trtllm"
        elif has_flashinfer():
            self.backend = "flashinfer-cutlass"
        elif cutlass_fp4_supported():
            self.backend = "cutlass"
        elif is_fp4_marlin_supported():
            self.backend = "marlin"
        else:
            raise ValueError(
                "Current platform does not support NVFP4"
                " quantization. Please use Blackwell and"
                " above."
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed NVFP4 weight and its scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized, or the
                input size per partition is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Weight: 2 fp4 items are packed per uint8 in the input dimension.
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input scale: one entry per fused output partition (reduced to a
        # scalar in process_weights_after_loading).
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global weight scale: one entry per fused output partition.
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per-block weight scale: one FP8-E4M3 value per `group_size` inputs.
        # (Serialized NVFP4 is guaranteed by the check above.)
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce the global scales to scalars and lay out the weight and
        block scale for the selected backend."""
        # global scales: max over the fused partitions
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        # Combined scale passed as `alpha` to the FP4 GEMM in apply().
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the NVFP4 GEMM on ``x`` and add ``bias`` if given.

        ``x`` may have any number of leading dimensions; they are flattened
        for the GEMM and restored on the output.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        # Keep all leading dims of x (previously only x.shape[0] was kept,
        # which broke for non-2-D inputs).
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
        # The 2-D case is passed through untouched.
        x_2d = x if x.dim() == 2 else x.reshape(-1, x.shape[-1])

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x_2d, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend == "flashinfer-trtllm":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
        elif self.backend == "flashinfer-cutlass":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
        else:
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
1120
1121


1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Choose how many tokens one CTA tile of the MoE kernel handles.

    Estimates tokens per expert under a perfectly even routing, pads that to
    the next power of two, and clamps the result to [8, 64], the range the
    kernel supports.
    """
    even_share = num_tokens * top_k // num_experts
    return max(8, min(64, next_power_of_2(even_share)))


1132
1133
1134
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1135
    Args:
1136
1137
1138
        quant_config: NVFP4 Quant Config
    """

1139
1140
1141
1142
1143
1144
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Record the config and detect which NVFP4 MoE kernels are usable."""
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support,
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer

        support = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin

        # Cache of permute indices, keyed by tensor shape.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

        self.flashinfer_moe_backend = None
        if not self.allow_flashinfer:
            return

        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )
1164

1165
1166
1167
1168
1169
    def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Pick the prepare/finalize object for the active kernel.

        Marlin and FlashInfer TRT-LLM need none; FlashInfer CUTLASS uses the
        FlashInfer dispatcher; everything else defers to the base class.
        """
        active_backend = self.flashinfer_moe_backend if self.allow_flashinfer else None

        if self.use_marlin or active_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if active_backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize()

        # For now, fp4 moe only works with the flashinfer dispatcher.
        pf = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
        logger.debug_once("%s", pf.__class__.__name__)
        return pf
1183

1184
1185
1186
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the NVFP4 experts implementation for this MoE config."""
        quant_cfg = self.moe_quant_config
        assert quant_cfg is not None
        gemm_impl = select_nvfp4_gemm_impl(
            self.moe, quant_cfg, allow_flashinfer=self.allow_flashinfer
        )
        logger.debug_once("Using %s", gemm_impl.__class__.__name__)
        return gemm_impl
1197

1198
1199
1200
1201
1202
1203
    def uses_weight_scale_2_pattern(self) -> bool:
        """
        Report that per-tensor weight scales are stored as 'weight_scale_2'
        (the FP4 naming convention).
        """
        uses_scale_2 = True
        return uses_scale_2

1204
1205
1206
1207
1208
1209
1210
1211
1212
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register NVFP4 expert weights and their scales on ``layer``.

        Creates packed-FP4 (uint8) ``w13_weight`` / ``w2_weight``, their
        FP8-E4M3 per-block scales, float32 per-expert global scales
        (``*_weight_scale_2``) and float32 per-expert input scales. The w13
        tensors hold the fused w1/w3 projections, hence the ``2 *`` on the
        intermediate dimension and the trailing size-2 axis of the w13
        global/input scales.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        group_size = self.quant_config.group_size

        # GEMM 1 weight: 2 fp4 items are packed per byte along hidden_size.
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2 weight: 2 fp4 items are packed per byte along the
        # intermediate dimension.
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Per-block scales (not packed): one FP8 value per `group_size`
        # weight elements along the input dimension.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Per-expert global weight scales; the w13 variant has one entry per
        # fused w1/w3 shard.
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        # Per-expert input scales, laid out like the global weight scales.
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

1315
    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        """Prepare quantized weights for kernel (done offline with weights).

        Views the packed FP4 weights and FP8 block scales with their
        per-expert 3-D shapes, then, for every expert:

        1. reorders rows of the fused w3/w1 weights and scales for the fused
           gated activation, and
        2. shuffles weights and scale factors (w3/w1 and w2) for the
           transposed MMA output,

        using FlashInfer's cached permute indices, and interleaves the scale
        factors with ``nvfp4_block_scale_interleave``.

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)``,
            each stacked over experts; scales are viewed as
            ``torch.float8_e4m3fn``.
        """
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w2_permute_indices,
            _maybe_get_cached_w3_w1_permute_indices,
        )

        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        # Number of FP4 elements covered by one FP8 scale factor.
        elts_per_sf = 16

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // elts_per_sf
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // elts_per_sf
        )  # fp8 scaling factors

        def _shuffle_weight(expert_weight, get_permute_indices):
            # Reorder the rows of one expert's packed weight.
            weight_u8 = expert_weight.view(torch.uint8)
            permute_indices = get_permute_indices(
                self._cache_permute_indices, weight_u8, epilogue_tile_m
            )
            return weight_u8[permute_indices.to(expert_weight.device)].contiguous()

        def _shuffle_scale(expert_scale, get_permute_indices):
            # Reorder the rows of one expert's block scales, then interleave.
            scale_u8 = expert_scale.view(torch.uint8)
            permute_sf_indices = get_permute_indices(
                self._cache_permute_indices,
                scale_u8,
                epilogue_tile_m,
                num_elts_per_sf=elts_per_sf,
            )
            return nvfp4_block_scale_interleave(
                scale_u8[permute_sf_indices.to(expert_scale.device)].contiguous()
            )

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        # Iterating a tensor yields its dim-0 slices, i.e. one per expert.
        for w1, s1, w2, s2 in zip(
            gemm1_weights_fp4,
            gemm1_scales_linear_fp4,
            gemm2_weights_fp4,
            gemm2_scales_linear_fp4,
        ):
            gemm1_weights_fp4_shuffled.append(
                _shuffle_weight(w1, _maybe_get_cached_w3_w1_permute_indices)
            )
            gemm1_scales_fp4_shuffled.append(
                _shuffle_scale(s1, _maybe_get_cached_w3_w1_permute_indices)
            )
            gemm2_weights_fp4_shuffled.append(
                _shuffle_weight(w2, _maybe_get_cached_w2_permute_indices)
            )
            gemm2_scales_fp4_shuffled.append(
                _shuffle_scale(s2, _maybe_get_cached_w2_permute_indices)
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // elts_per_sf)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // elts_per_sf)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )
1438

1439
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded NVFP4 MoE tensors into the selected kernel layout.

        Steps, in order:

        1. When FlashInfer is allowed, reorder the fused w13 weight and its
           block scales with ``reorder_w1w3_to_w3w1`` along dim -2.
        2. Collapse ``w13_weight_scale_2`` to its column 0, warning when
           columns 0 and 1 (w1 / w3) disagree.
        3. Derive the GEMM output scales (``g1_alphas``, ``g2_alphas``) and
           the inverted activation scales used for input quantization
           (``w13_input_scale_quant``, ``w2_input_scale_quant``).
        4. Backend-specific layout:

           - FlashInfer TRT-LLM: shuffle weights/scales into the
             ``gemm{1,2}_{weights,scales}_fp4_shuffled`` parameters, add
             ``g1_scale_c`` and delete the original weight/scale params.
           - Marlin: repack with ``prepare_moe_fp4_layer_for_marlin`` and
             delete ``g1_alphas``, ``g2_alphas`` and the ``*_input_scale_quant``
             tensors, which the Marlin path of ``apply`` does not read.
           - Otherwise (CUTLASS): validate and swizzle the FP8-E4M3 block
             scales of both GEMMs.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2. Only column 0 is kept
        # below, so the w1 and w3 global scales are expected to agree.
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas.
        # Max over dim 1 so w1 and w3 share a single activation scale
        # (assumes w13_input_scale is (num_experts, 2) -- TODO confirm).
        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM: GEMM1 output scale
            # folded with the GEMM2 input quantization scale.
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer).
            # The block scales are 3-D per-expert tensors; the checked axis
            # is dim 2.
            assert layer.w13_weight_scale.shape[2] % 16 == 0, (
                "Expected weight_scale.dim(2) to be divisible by 16"
            )
            assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, (
                "Weight Blockscale must be represented as FP8-E4M3"
            )
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            assert layer.w2_weight_scale.shape[2] % 16 == 0, (
                "Expected weight_scale.dim(2) to be divisible by 16"
            )
            assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, (
                "Weight Blockscale must be represented as FP8-E4M3"
            )
            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> Optional[FusedMoEQuantConfig]:
        """Return the NVFP4 quant config carrying this layer's scales.

        Returns ``None`` for the Marlin and FlashInfer TRT-LLM paths. Those
        paths call their kernels directly and read the scales from ``layer``.
        The TRT-LLM test requires ``self.allow_flashinfer``, matching the
        checks in ``process_weights_after_loading`` and ``apply``. Without
        that, a layer that ``apply`` routes to a CUTLASS kernel could receive
        ``None`` here.
        """
        uses_trtllm = (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if self.use_marlin or uses_trtllm:
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the NVFP4 fused-MoE forward pass for ``x``.

        Dispatch order (first match wins):

        1. FlashInfer TRT-LLM backend: quantize ``x`` to FP4 with
           ``flashinfer.fp4_quantize`` and call
           ``flashinfer.fused_moe.trtllm_fp4_block_scale_moe``, which also
           performs routing from ``router_logits``.
        2. Marlin: ``torch.ops.vllm.fused_marlin_moe``.
        3. A modular kernel in ``self.fused_experts`` (FlashInfer CUTLASS only).
        4. FlashInfer CUTLASS: ``flashinfer_cutlass_moe_fp4``.
        5. Fallback: ``cutlass_moe_fp4``.

        Paths 2-5 select experts with ``FusedMoE.select_experts`` first.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            assert self.fused_experts is None

            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )
            # Llama4's router is recognised by identity so the kernel can use
            # its dedicated routing method; everything else uses DeepSeekV3.
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4
            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)
            # NOTE(review): renormalize, scoring_func, expert_map,
            # apply_router_weight_on_input and routed_scaling_factor are not
            # forwarded on this path (routed_scaling_factor is passed as None);
            # routing is driven by routing_method_type -- confirm intended.
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                # Non-Llama4 routing is given float32 logits.
                routing_logits=router_logits
                if use_llama4_routing
                else router_logits.to(torch.float32),
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=_get_tile_tokens_dim(
                    x.shape[0], top_k, layer.local_num_experts
                ),
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
        # trtllm.
        #
        if self.use_marlin:
            assert self.fused_experts is None
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        elif self.fused_experts is not None:
            assert (
                self.allow_flashinfer
                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            )

            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight
            ), "Flashinfer CUTLASS Fused MoE not applicable!"

            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4,
            )

            assert self.moe_quant_config is not None

            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )