modelopt.py 68.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.fused_moe.config import (
16
17
18
19
20
    FusedMoEConfig,
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
22
23
    is_valid_flashinfer_cutlass_fused_moe,
)
24
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
25
from vllm.model_executor.layers.fused_moe.layer import (
26
27
28
29
30
31
32
33
34
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
35
from vllm.model_executor.layers.quantization import QuantizationMethods
36
from vllm.model_executor.layers.quantization.base_config import (
37
38
39
    QuantizationConfig,
    QuantizeMethodBase,
)
40
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
41
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
42
43
44
45
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
    reorder_w1w3_to_w3w1,
    select_nvfp4_gemm_impl,
)
46
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
47
48
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
49
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
50
51
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
52
    is_flashinfer_supporting_global_sf,
53
54
55
56
57
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
58
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
59
60
61
62
63
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
    prepare_moe_fp4_layer_for_marlin,
)
64
from vllm.model_executor.layers.quantization.utils.quant_utils import (
65
66
67
68
69
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
    swizzle_blockscale,
)
70
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
71
72
73
74
    Fp8LinearOp,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
75
from vllm.scalar_type import scalar_types
76
77
78
79
80
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
    has_flashinfer_moe,
)
81

82
83
84
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

85
86
# Module-level logger shared by every config / method class in this file.
logger = init_logger(__name__)

# Checkpoint ``quant_algo`` values the ModelOpt configs can load; any other
# value is rejected with a ValueError in ``from_config``.
QUANT_ALGOS = [
    "FP8",
    "NVFP4",
]
# KV-cache quantization algorithms (``kv_cache_quant_algo``) for ModelOpt
# checkpoints. NOTE(review): not referenced in the visible part of this file;
# presumably validated further down — confirm.
KV_CACHE_QUANT_ALGOS = [
    "FP8",
]


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Can be built from a ModelOpt ``hf_quant_config.json`` (settings nested
    under a ``"quantization"`` key) or from a compressed-tensors style config
    that carries ``quant_algo`` / ``quant_method`` at the top level.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: str | None = None,
        exclude_modules: list[str] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint stores weights
                already quantized to FP8.
            kv_cache_quant_method: The checkpoint's ``kv_cache_quant_algo``
                value, if any.
            exclude_modules: Module names/patterns to leave unquantized.
                ``None`` is normalized to an empty list.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Map ``exclude_modules`` through ``hf_to_vllm_mapper.apply_list``."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``hf_quant_cfg`` declares
        ``quant_method == "modelopt"`` (case-insensitive) and a string
        ``quant_algo`` containing ``"FP8"``. Otherwise returns ``None``.
        Missing, null or non-string values count as "no match" and do not
        raise.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. The value may be an
        # explicit JSON null (or another non-string), which must not reach
        # ``.lower()``.
        quant_method = hf_quant_cfg.get("quant_method", "")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # ModelOpt-specific structure: {"quantization": {"quant_algo": ...}}
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Compressed-tensors style config with a top-level quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        # Both layouts use the same guard, so a null quant_algo no longer
        # raises TypeError on the ``in`` check.
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from either supported checkpoint layout.

        Raises:
            ValueError: If ``"quantization"`` is not a dict, if it lacks
                ``quant_algo``, or if the algorithm is not in ``QUANT_ALGOS``.
        """
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unsupported.

        Excluded linear layers and vision-tower/vision-model linear layers
        stay unquantized.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Loads FP8 checkpoints that carry a static per-tensor weight scale and a
    static per-tensor activation (input) scale. Dynamic activation scales
    may be supported in the future.

    Limitations:
    1. Only per-tensor quantization is supported (torch._scaled_mm support).
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Activations are quantized with one static scale per tensor.
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` (and, for FP8 checkpoints, its scales) on layer.

        ``weight`` is allocated as (sum(output_partition_sizes),
        input_size_per_partition). For FP8-serialized checkpoints, one
        ``weight_scale`` entry and one ``input_scale`` entry are allocated per
        output partition. Both are filled with float32 min, so an entry that
        is never loaded cannot win the ``max()`` taken after loading.
        """
        del input_size, output_size
        loader = extra_weight_attrs.get("weight_loader")
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        total_out = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        storage_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_out, input_size_per_partition, dtype=storage_dtype
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not fp8_serialized:
            return

        # WEIGHT SCALE first, then INPUT SCALE: one slot per output shard each.
        num_shards = len(output_partition_sizes)
        for param_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.empty(num_shards, dtype=torch.float32),
                weight_loader=loader,
            )
            scale_param[:] = torch.finfo(torch.float32).min
            layer.register_parameter(param_name, scale_param)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales into a single per-tensor scale.

        When the loaded shard weight scales differ, the weight is requantized
        against their maximum. The weight is then stored transposed, and the
        input scale is reduced to its maximum.
        """
        shard_scales = layer.weight_scale
        if (shard_scales == shard_scales[0]).all():
            # Every shard already shares one scale: no requantization needed.
            unified_scale, qweight = shard_scales.max(), layer.weight
        else:
            unified_scale, qweight = requantize_with_max_scale(
                layer.weight, shard_scales, layer.logical_widths
            )
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scale = Parameter(unified_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op with the layer's weight, scales and bias."""
        op = self.fp8_linear
        return op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )


class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    static per-tensor activation scales. Experts run through one of four
    paths:
      * FlashInfer TensorRT-LLM kernels;
      * FlashInfer CUTLASS kernels;
      * a modular-kernel ``self.fused_experts`` override;
      * the default ``fused_experts`` implementation.

    Args:
        quant_config: The ModelOpt quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is passed to the base class.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported,
        )

        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        # FlashInfer kernels are opt-in via VLLM_USE_FLASHINFER_MOE_FP8 and
        # are only selected for gated (act-and-mul) MoE.
        if (
            envs.VLLM_USE_FLASHINFER_MOE_FP8
            and has_flashinfer_moe()
            and self.moe.is_act_and_mul
        ):
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Pick the prepare/finalize stage for the active backend.

        Returns None for TensorRT-LLM and the FlashInfer CUTLASS
        prepare/finalize for CUTLASS. Otherwise defers to the base class.
        """
        # TRT LLM not supported with all2all yet.
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the CUTLASS FP8 experts implementation for modular kernels.

        Requires ``self.moe_quant_config`` to be populated.
        """
        assert self.moe_quant_config is not None
        experts = select_cutlass_fp8_gemm_impl(
            self.moe,
            self.moe_quant_config,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and, for FP8 checkpoints, their scales.

        Registers:
          * ``w13_weight`` with shape (num_experts, w13_up_dim, hidden_size).
            ``w13_up_dim`` is ``2 * intermediate_size_per_partition`` for
            gated MoE and ``intermediate_size_per_partition`` otherwise.
          * ``w2_weight`` with shape
            (num_experts, hidden_size, intermediate_size_per_partition).
          * FP8-serialized checkpoints only, all filled with 1.0:
            - ``w13_weight_scale``: (num_experts, 2) if gated, otherwise
              (num_experts, 1);
            - ``w2_weight_scale``, ``w13_input_scale`` and
              ``w2_input_scale``: (num_experts,) each.

        Only the ``weight_loader`` entry of ``extra_weight_attrs`` is used.
        """
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.moe.is_act_and_mul:
            w13_up_dim = 2 * intermediate_size_per_partition
        else:
            w13_up_dim = intermediate_size_per_partition

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_up_dim,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
            # For gated MoE, allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            # For non-gated MoE, allocate 1 scale for w13.
            if self.moe.is_act_and_mul:
                w13_weight_scale_shape = (num_experts, 2)
            else:
                w13_weight_scale_shape = (num_experts, 1)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_weight_scale_shape,
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Steps:
          1. If ``w13_weight_scale`` holds separate w1/w3 scales (shape
             (E, 2)), dequantize each expert's w1 and w3 halves and
             requantize them to the max of the two. The kernel then sees one
             w13 scale per expert.
          2. Re-wrap weights and scales as non-trainable Parameters. Reduce
             each input scale to a single scalar (its max over experts).
          3. For FlashInfer backends, apply ``swap_w13_to_w31`` and
             ``register_moe_scaling_factors``. TensorRT-LLM additionally
             applies ``rotate_flashinfer_fp8_moe_weights``.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize,
        )

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if (
                layer.w13_weight_scale.dim() == 2
                and layer.w13_weight_scale.shape[1] == 2
            ):
                assert self.moe.is_act_and_mul, (
                    "w13_weight_scale should have 2 elements per expert "
                    "only for gated MoE"
                )
                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale, writing the
                        # result back into the same rows.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if self.flashinfer_moe_backend is not None:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config from the layer's per-tensor scales.

        Returns None for the TensorRT-LLM backend, whose ``apply`` path does
        not use ``moe_quant_config``.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=False,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route ``x`` to experts and run the FP8 expert computation.

        The TensorRT-LLM backend takes ``router_logits`` directly. Every other
        path first selects experts with ``FusedMoE.select_experts``.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet."
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert self.fused_experts is None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        # Expert selection
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # cutlass or fused_experts.
        #
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

            assert self.moe_quant_config is not None

            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                quant_config=self.moe_quant_config,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )


class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        """Store the parsed ModelOpt NVFP4 settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint weights are
                already stored in NVFP4 format.
            kv_cache_quant_algo: KV-cache quantization algorithm name, or None
                when the KV cache is not quantized.
            exclude_modules: Module names / ``*`` wildcard patterns that must
                not be quantized.
            group_size: Number of input elements sharing one block scale.
        """
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # These are read unconditionally (apply_vllm_mapper,
        # is_layer_excluded, create_weights), so they must be set whether or
        # not the checkpoint is NVFP4-serialized.
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename exclude patterns from HF module names to vLLM names."""
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. A null or non-string
        # value simply means "not ModelOpt" instead of crashing on .lower().
        quant_method = hf_quant_cfg.get("quant_method")
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @staticmethod
    def _parse_kv_cache_quant_algo(section: dict[str, Any]) -> str | None:
        # Read "kv_cache_quant_algo": absent/None means no KV-cache quant.
        kv_cache_quant_algo_raw = section.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None or isinstance(kv_cache_quant_algo_raw, str):
            return kv_cache_quant_algo_raw
        raise ValueError(
            f"kv_cache_quant_algo must be a string, got "
            f"{type(kv_cache_quant_algo_raw)}"
        )

    @staticmethod
    def _parse_group_size(section: dict[str, Any]) -> int:
        # Read "group_size", defaulting to 16 and coercing int-like values.
        group_size_raw = section.get("group_size")
        if group_size_raw is None:
            return 16
        if isinstance(group_size_raw, int):
            return group_size_raw
        try:
            return int(group_size_raw)
        except (ValueError, TypeError):
            raise ValueError(
                f"group_size must be an integer, got {type(group_size_raw)}"
            ) from None

    @staticmethod
    def _parse_exclude_modules(section: dict[str, Any], key: str) -> list[str]:
        # Read the exclude list stored under `key`; it must be a list.
        exclude_modules = section.get(key, [])
        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )
        return exclude_modules

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        """Build a config from a parsed quantization config dict.

        Two layouts are accepted:

        * legacy ``hf_quant_config.json``:
          ``{"quantization": {"quant_algo": ..., "kv_cache_quant_algo": ...,
          "group_size": ..., "exclude_modules": [...]}}``
        * compressed-tensors style ``config.json``:
          ``{"quant_method": "modelopt", "quant_algo": ..., "ignore": [...]}``

        Raises:
            ValueError: On malformed fields, an unsupported ``quant_algo``, or
                (legacy layout with NVFP4) missing required fields.
        """
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            section = quant_config
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_key = "exclude_modules"
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            section = config
            # "ignore" is the key in config.json
            exclude_key = "ignore"

        # Same validation order for both layouts.
        kv_cache_quant_algo = cls._parse_kv_cache_quant_algo(section)
        group_size = cls._parse_group_size(section)
        exclude_modules = cls._parse_exclude_modules(section, exclude_key)

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check regex pattern matching for patterns not caught by exact match
        import regex as re

        for pattern in self.exclude_modules:
            # Only patterns with a wildcard or a dotted path need matching here
            if "*" in pattern or "." in pattern:
                # Escape every regex metacharacter (not just "."), then turn
                # the escaped glob wildcard back into ".*".
                regex_str = re.escape(pattern).replace(r"\*", ".*")
                if re.fullmatch(regex_str, prefix):
                    return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for `layer`, or None to leave it as is."""
        from vllm.attention.layer import Attention  # Avoid circular import

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantize method shared by the ModelOpt configs.

    Loads kv-cache scaling factors from FP8 checkpoints. It adds nothing on
    top of ``BaseKVCacheMethod``; the constructor only forwards the config.
    """

    def __init__(
        self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config
    ) -> None:
        super().__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters
    (shapes as allocated by ``create_weights``):

    - ``weight``: packed NVFP4, two values per uint8, shape
      [output_size_per_partition, input_size_per_partition // 2].
    - ``weight_scale``: FP8-E4M3 per-block scale, shape
      [output_size_per_partition, input_size_per_partition // group_size].
    - ``weight_scale_2``: float32 global weight scale, one per logical output
      partition (reduced to a scalar after loading).
    - ``input_scale``: float32 global input scale, one per logical output
      partition (reduced to a scalar after loading).

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        # An explicit VLLM_NVFP4_GEMM_BACKEND wins; otherwise pick the first
        # backend available on this platform.
        requested_backend = envs.VLLM_NVFP4_GEMM_BACKEND
        self.backend = "none"
        if requested_backend is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif requested_backend.startswith("flashinfer-"):
            self.backend = requested_backend
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif requested_backend == "cutlass":
            # Honor an explicit non-FlashInfer request instead of falling
            # through to the generic "no valid backend" error below.
            if not cutlass_fp4_supported():
                raise ValueError(
                    "NVFP4 CUTLASS GEMM was requested via "
                    "VLLM_NVFP4_GEMM_BACKEND but is not supported on this "
                    "platform."
                )
            self.backend = "cutlass"
        elif requested_backend == "marlin":
            if not is_fp4_marlin_supported():
                raise ValueError(
                    "NVFP4 Marlin GEMM was requested via "
                    "VLLM_NVFP4_GEMM_BACKEND but is not supported on this "
                    "platform."
                )
            self.backend = "marlin"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (empty) NVFP4 parameters on ``layer``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized or the
                per-partition input size is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # Only NVFP4-serialized checkpoints get here (checked above), so the
        # per-block weight scales are always stored as FP8-E4M3.
        weight_scale_dtype = torch.float8_e4m3fn

        # Weight: 2 fp4 items are packed per uint8 in the input dimension
        weight = ModelWeightParameter(
            data=torch.empty(
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input scale: one float32 per logical output partition
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global weight scale: one float32 per logical output partition
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per-block weight scale: one per `group_size` input elements
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce global scales to scalars, precompute alpha and
        1 / input_scale, then lay out weights/scales for the chosen backend."""
        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        # Combined output scale passed to the FP4 GEMM.
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W.T (+ bias)`` with the selected NVFP4 GEMM backend.

        For non-Marlin backends the result keeps all leading dims of ``x``
        and has ``layer.weight.shape[0]`` as its last dim.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        output_dtype = x.dtype
        # Keep every leading dim of x (not only dim 0) so N-D inputs are
        # reshaped correctly; for 2-D x this equals [x.shape[0], out].
        output_shape = [*x.shape[:-1], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
1147
1148
1149
1150
1151


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
1152
    Args:
1153
1154
1155
        quant_config: NVFP4 Quant Config
    """

1156
1157
1158
1159
1160
1161
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Set up NVFP4 MoE state and pick the kernel family.

        Args:
            quant_config: Parsed ModelOpt NVFP4 config.
            moe: MoE layer configuration, forwarded to the base class.
            layer: The FusedMoE module this method is attached to.
        """
        # Imported lazily, inside the constructor.
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support,
        )

        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer

        # Permutation indices reused across the TRT-LLM weight shuffles.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

        # Which NVFP4 MoE kernel families are usable here.
        support = detect_nvfp4_moe_support(type(self).__name__)
        self.cutlass_nvfp4_supported = support.cutlass_supported
        self.allow_flashinfer = support.allow_flashinfer
        self.use_marlin = support.use_marlin

        self.flashinfer_moe_backend = None
        if not self.allow_flashinfer:
            return

        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE."
        )
1181

1182
    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Choose the prepare/finalize stage for the modular MoE kernel.

        Returns None for Marlin and for the FlashInfer TRT-LLM backend, the
        FlashInfer FP4 dispatcher for the FlashInfer CUTLASS backend, and
        otherwise defers to the base class.
        """
        # Only consider the FlashInfer backend when FlashInfer is allowed.
        flashinfer_backend = (
            self.flashinfer_moe_backend if self.allow_flashinfer else None
        )

        if self.use_marlin or flashinfer_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if flashinfer_backend == FlashinferMoeBackend.CUTLASS:
            # For now, fp4 moe only works with the flashinfer dispatcher.
            dispatcher = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
            logger.debug_once("%s", dispatcher.__class__.__name__)
            return dispatcher

        return super().maybe_make_prepare_finalize()
1200

1201
1202
1203
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Build the NVFP4 experts implementation for the modular kernel.

        ``prepare_finalize`` and ``layer`` are accepted but not used here;
        requires ``self.moe_quant_config`` to be set.
        """
        quant_config = self.moe_quant_config
        assert quant_config is not None
        gemm_impl = select_nvfp4_gemm_impl(
            self.moe, quant_config, allow_flashinfer=self.allow_flashinfer
        )
        logger.debug_once("Using %s", type(gemm_impl).__name__)
        return gemm_impl
1214

1215
1216
1217
1218
1219
1220
    def uses_weight_scale_2_pattern(self) -> bool:
        """Report that per-tensor weight scales use the ``weight_scale_2`` name.

        The NVFP4 MoE parameters are registered as ``w13_weight_scale_2`` /
        ``w2_weight_scale_2``, so this is unconditionally True.
        """
        uses_pattern: bool = True
        return uses_pattern

1221
1222
1223
1224
1225
1226
1227
1228
1229
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (empty) NVFP4 MoE parameters on ``layer``.

        Creates packed FP4 weights (two values per uint8), FP8-E4M3 block
        scales, float32 per-expert global weight scales and float32 input
        scales. All are allocated with ``torch.empty`` and given
        ``weight_loader``.

        Args:
            layer: The FusedMoE module receiving the parameters.
            num_experts: Number of experts allocated in this layer (compare
                ``global_num_experts``).
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Expert intermediate size on this
                partition.
            params_dtype: Original parameter dtype, recorded on the layer.
            **extra_weight_attrs: Supplies ``weight_loader`` and, when the
                FlashInfer backend uses global input scales,
                ``global_num_experts``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                global input scales are needed but ``global_num_experts``
                was not provided.
        """
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        group_size = self.quant_config.group_size

        # Input scales are sized by the global expert count when the
        # FlashInfer backend supports global scale factors. Validate this up
        # front rather than failing inside torch.empty after other params
        # have already been registered.
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf and global_num_experts is None:
            raise ValueError(
                "global_num_experts must be provided in extra_weight_attrs "
                "when the FlashInfer MoE backend uses global scale factors."
            )
        global_scale_num_experts = global_num_experts if use_global_sf else num_experts

        # GEMM 1 (fused w1/w3): 2 fp4 items are packed per uint8 along the
        # hidden (input) dimension.
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2: 2 fp4 items are packed per uint8 along the intermediate
        # (input) dimension.
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # Block scales (not packed): one FP8-E4M3 scale per `group_size`
        # input elements.
        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): extra_weight_attrs is a local kwargs dict and is not
        # read again after these updates, so they have no effect in this
        # method.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        # Per-expert global weight scales (separate values for w1 and w3).
        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

1338
    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
        gemm1_weights,
        gemm2_weights,
        gemm1_scales_linear_fp4_bytes,
        gemm2_scales_linear_fp4_bytes,
        hidden_size,
        intermediate_size,
        num_experts,
    ):
        """Prepare quantized weights for kernel (done offline with weights).

        For every expert this reorders the rows of the fused W3/W1 weights
        and scales for the gated activation. It also shuffles the weights and
        block scales of both GEMMs for the transposed-MMA output layout, using
        FlashInfer's permutation helpers. Block scales are additionally passed
        through ``nvfp4_block_scale_interleave``.

        Args:
            gemm1_weights: Packed FP4 GEMM-1 weights, reshaped to
                [num_experts, 2 * intermediate_size, hidden_size // 2].
            gemm2_weights: Packed FP4 GEMM-2 weights, reshaped to
                [num_experts, hidden_size, intermediate_size // 2].
            gemm1_scales_linear_fp4_bytes: GEMM-1 block scales, reshaped to
                [num_experts, 2 * intermediate_size, hidden_size // 16].
            gemm2_scales_linear_fp4_bytes: GEMM-2 block scales, reshaped to
                [num_experts, hidden_size, intermediate_size // 16].
            hidden_size: Model hidden dimension.
            intermediate_size: Expert intermediate dimension.
            num_experts: Number of experts in the tensors above.

        Returns:
            ``(gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales)``:
            stacked shuffled weights as uint8 and stacked shuffled scales
            viewed as FP8-E4M3 with the shapes listed above.
        """
        from flashinfer import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import (
            _maybe_get_cached_w3_w1_permute_indices,
            get_w2_permute_indices_with_cache,
        )

        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        # NVFP4 block size: one FP8 scale per 16 fp4 values.
        sf_vec_size = 16

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2
        )  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, 2 * intermediate_size, hidden_size // sf_vec_size
        )  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2
        )  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn
        ).reshape(
            num_experts, hidden_size, intermediate_size // sf_vec_size
        )  # fp8 scaling factors

        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            # Calculate the permute indices for the following:
            # 1. Reorder rows of W1 and scales for fused gated activation
            # 2. Shuffle weights and scaling factors for transposed mma output
            # for both w3_w1 and w2 weights and scale factors
            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_fp4_shuffled.append(
                gemm1_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
                self._cache_permute_indices,
                gemm1_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm1_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm1_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

            permute_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_weights_fp4[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_fp4_shuffled.append(
                gemm2_weights_fp4[i]
                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
                .contiguous()
            )

            permute_sf_indices = get_w2_permute_indices_with_cache(
                self._cache_permute_indices,
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=sf_vec_size,
            )
            gemm2_scales_fp4_shuffled.append(
                nvfp4_block_scale_interleave(
                    gemm2_scales_linear_fp4[i]
                    .view(torch.uint8)[
                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
                    ]
                    .contiguous()
                )
            )

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_vec_size)
        )

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled)
            .view(torch.float8_e4m3fn)
            .reshape(num_experts, hidden_size, intermediate_size // sf_vec_size)
        )
        return (
            gemm1_weights_fp4_shuffled,
            gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled,
            gemm2_scales_fp4_shuffled,
        )
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded NVFP4 MoE parameters into the layout the backend uses.

        Steps, in order:

        1. When FlashInfer is allowed, reorder the fused w13 weight and its
           block scale along dim -2 with ``reorder_w1w3_to_w3w1``.
        2. Collapse ``w13_weight_scale_2`` to column 0. A warning is logged
           when the w1 and w3 columns differ.
        3. Derive the activation global scales. They are shared across experts
           for FlashInfer backends that support a global scale factor, and
           per-expert otherwise. From them, compute:

           - the GEMM alphas ``g1_alphas`` / ``g2_alphas``;
           - the inverted quantization scales ``w13_input_scale_quant`` /
             ``w2_input_scale_quant``.
        4. Apply the backend-specific layout:

           - FlashInfer TensorRT-LLM: shuffle weights and scales via
             ``prepare_static_weights_for_trtllm_fp4_moe`` and store them as
             ``gemm{1,2}_{weights,scales}_fp4_shuffled``. Add ``g1_scale_c``,
             then delete the original w13/w2 weights and block scales.
           - Marlin: repack with ``prepare_moe_fp4_layer_for_marlin`` and drop
             the alpha / activation-quant-scale parameters.
           - Otherwise (CUTLASS or non-FlashInfer): swizzle the w13 and w2
             block scales.

        Args:
            layer: The MoE layer whose parameters are replaced in place.
        """
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            # Reorder the fused w1/w3 rows (dim -2) of both the packed weight
            # and its block scale so they stay aligned with each other.
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

        # Common processing for w13_weight_scale_2: only one global weight
        # scale per expert is kept (the w1 column), so warn if w3's differs.
        if not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

        # Common processing for input scales and alphas
        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
            self.flashinfer_moe_backend
        )
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts, so the max over all experts is used.
            # NOTE(review): the scalar is expanded to layer.num_experts; confirm
            # this matches the per-rank expert count of w13_weight_scale_2 when
            # expert parallelism is enabled.
            w13_input_scale = (
                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            # Per expert: the larger of the scales along dim 1 (assumed to be
            # the w1/w3 pair -- confirm against the weight loader).
            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)

        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )

        # GEMM 2 processing
        if use_global_sf:
            # For backends provided by FlashInfer, the input global scales are
            # shared across all experts (same rule as GEMM 1).
            w2_input_scale = (
                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
            )
        else:
            w2_input_scale = layer.w2_input_scale

        layer.g2_alphas = Parameter(
            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            # Prepare static weights for TRT-LLM kernel
            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM: the product of the
            # GEMM-2 activation quant scale and the GEMM-1 alpha (passed to the
            # kernel as output1_scale_scalar).
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        elif self.use_marlin:
            # Marlin processing. The alphas and activation quant scales
            # computed above are not used by this backend, so drop them.
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(
                w13_blockscale_swizzled, requires_grad=False
            )

            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(
                w2_blockscale_swizzled, requires_grad=False
            )
            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the NVFP4 quant config used by the modular MoE kernels.

        Returns ``None`` for Marlin and for the FlashInfer TensorRT-LLM
        backend. Both are dispatched directly in ``apply``. After
        ``process_weights_after_loading``, neither layer keeps the tensors this
        config is built from:

        - Marlin deletes the alphas and activation quant scales.
        - TensorRT-LLM deletes the original weights and block scales.

        Args:
            layer: The MoE layer, after ``process_weights_after_loading``.

        Returns:
            An NVFP4 ``FusedMoEQuantConfig`` built from the layer's block
            scales, GEMM alphas and activation global scales, or ``None``.
        """
        # Gate the TRT-LLM check on allow_flashinfer exactly like
        # process_weights_after_loading and apply do. Otherwise a layer whose
        # weights were prepared for CUTLASS could get no quant config.
        if self.use_marlin or (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        return nvfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
        )
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the NVFP4 fused-MoE forward pass on the backend chosen at load.

        Dispatch order:

        1. FlashInfer TensorRT-LLM: quantize ``x`` with
           ``flashinfer.fp4_quantize``, then call
           ``flashinfer.fused_moe.trtllm_fp4_block_scale_moe``. That kernel
           receives the raw router logits and does the routing itself.
        2. Otherwise, route with ``FusedMoE.select_experts`` and run one of the
           following, in order of precedence:

           - Marlin (``fused_marlin_moe``);
           - the modular kernel in ``self.fused_experts`` (FlashInfer CUTLASS
             only);
           - ``cutlass_moe_fp4``.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
            AssertionError: If ``activation`` is not ``"silu"`` or a backend
                precondition asserted below does not hold.

        Returns:
            The MoE output produced by the selected backend.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
            )
        assert activation == "silu", "Only SiLU activation is supported."

        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            assert self.fused_experts is None

            # Quantize activations with the (inverted) GEMM-1 input scale, using
            # a linear (non-swizzled) scale-factor layout.
            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
                flashinfer.fp4_quantize(
                    x,
                    a1_gscale,
                    is_sf_swizzled_layout=False,
                )
            )

            # NOTE(review): this path does not use renormalize,
            # use_grouped_topk, scoring_func, expert_map,
            # apply_router_weight_on_input or routed_scaling_factor (it passes
            # None). Routing is DeepSeekV3 unless the Llama4 routing function is
            # in use. Confirm this is correct for every model that reaches here.
            use_llama4_routing = (
                custom_routing_function is Llama4MoE.custom_routing_function
            )
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4

            # Llama4 routing keeps the logits' dtype; other routing uses fp32.
            routing_logits = (
                router_logits
                if use_llama4_routing
                else router_logits.to(torch.float32)
            )
            routing_bias = e_score_correction_bias
            if routing_bias is not None:
                routing_bias = routing_bias.to(torch.bfloat16)

            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=routing_logits,
                routing_bias=routing_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn
                ).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn
                ),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        #
        # Note: the order here is important. self.fused_experts can override
        # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
        # trtllm.
        #
        if self.use_marlin:
            assert self.fused_experts is None
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )

        elif self.fused_experts is not None:
            # A modular kernel is only installed for FlashInfer CUTLASS.
            assert (
                self.allow_flashinfer
                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            )

            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight
            ), "Flashinfer CUTLASS Fused MoE not applicable!"

            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

            assert self.moe_quant_config is not None
            return cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                quant_config=self.moe_quant_config,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
                # TODO: derive from arguments
                m=x.shape[0],
                # Presumably w2 packs two FP4 values per byte along its last
                # dim, so doubling recovers the intermediate size -- confirm.
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
            )