modelopt.py 61.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from fnmatch import fnmatch
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
12
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
14
from vllm.attention.layer import Attention
15
from vllm.logger import init_logger
16
from vllm.model_executor.layers.fused_moe.config import (
17
    FusedMoEConfig,
18
19
    FusedMoEQuantConfig,
)
20
from vllm.model_executor.layers.fused_moe.layer import (
21
22
23
24
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
25
26
27
28
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
zhuwenwen's avatar
zhuwenwen committed
29
    make_fp8_moe_kernel_for_mkm,
30
31
32
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
33
34
35
36
37
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
zhuwenwen's avatar
zhuwenwen committed
38
    make_nvfp4_moe_kernel_for_mkm,
39
40
41
    make_nvfp4_moe_quant_config,
    select_nvfp4_moe_backend,
)
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
47
from vllm.model_executor.layers.quantization import QuantizationMethods
48
from vllm.model_executor.layers.quantization.base_config import (
49
50
51
    QuantizationConfig,
    QuantizeMethodBase,
)
52
53
54
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
55
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
56
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
zhuwenwen's avatar
zhuwenwen committed
57
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
58
    flashinfer_trtllm_fp4_moe,
59
    flashinfer_trtllm_fp4_routed_moe,
60
)
61
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
62
    apply_fi_trtllm_fp8_per_tensor_moe,
zhuwenwen's avatar
zhuwenwen committed
63
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
64
)
65
66
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
67
68
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_tensor_strategy_moe,
69
)
70
71
72
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
73
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
74
75
76
77
    apply_fp4_marlin_linear,
    is_fp4_marlin_supported,
    prepare_fp4_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
81
    GroupShape,
    cutlass_fp4_supported,
    is_layer_skipped,
82
83
84
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kFp8StaticTokenSym,
85
86
    kNvfp4Dynamic,
    kNvfp4Static,
87
88
    swizzle_blockscale,
)
89
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
90
    cutlass_block_fp8_supported,
91
92
    requantize_with_max_scale,
)
93
94
95
96
97
98
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
99
from vllm.model_executor.utils import replace_parameter
100
101
102
103
from vllm.utils.flashinfer import (
    flashinfer_scaled_fp4_mm,
    has_flashinfer,
)
104

105
106
107
if TYPE_CHECKING:
    # Imported for type annotations only; never imported at runtime.
    from vllm.model_executor.models.utils import WeightsMapper

# Module-level logger shared by every class in this file.
logger = init_logger(__name__)
# ModelOpt ``quant_algo`` values accepted by this module. ``from_config``
# upper-cases the checkpoint value before comparing it against this list.
QUANT_ALGOS = [
    "FP8",  # per-tensor FP8 weight, optional static activation scale
    "FP8_PER_CHANNEL_PER_TOKEN",  # per-channel weight, per-token activation
    "FP8_PB_WO",  # per-block FP8 weight-only (may be emitted in lowercase)
    "NVFP4",  # FP4
]

# ModelOpt ``kv_cache_quant_algo`` values accepted by this module.
KV_CACHE_QUANT_ALGOS = ["FP8"]
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for ModelOpt FP8 checkpoints.

    Loads kv-cache scaling factors from FP8 checkpoints. Everything beyond
    construction is inherited unchanged from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        # Hand the ModelOpt config straight to the generic KV-cache method.
        super().__init__(quant_config)

class ModelOptQuantConfigBase(QuantizationConfig):
    """Common base class for ModelOpt quantization configs.

    Parses the ModelOpt quantization section of a checkpoint config, decides
    which layers are excluded from quantization, and dispatches each layer to
    the method class configured on the subclass (``LinearMethodCls``,
    ``FusedMoEMethodCls``, ``KVCacheMethodCls``).
    """

    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        # ModelOpt wildcard patterns naming modules that stay unquantized;
        # rewritten in place by ``apply_vllm_mapper``.
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """Check if a layer should be excluded from quantization.

        Handles exact matching (including fused layers via
        ``packed_modules_mapping``), legacy substring matching, and ModelOpt
        wildcard (``fnmatch``) matching against ``exclude_modules``.

        Args:
            prefix: Fully-qualified name of the layer being constructed.

        Returns:
            True if any exclusion rule matches ``prefix``.
        """
        if len(self.exclude_modules) == 0:
            return False

        # First check exact matching with fused layer support.
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match.
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above.
            if exclude_module != prefix and (
                exclude_module in prefix
                or (
                    prefix.startswith("language_model.")
                    and exclude_module in prefix.removeprefix("language_model.")
                )
            ):
                return True

        # ModelOpt exclude modules are not simple strings, they are wildcards.
        return any(
            fnmatch(prefix, wildcard_pattern)
            for wildcard_pattern in self.exclude_modules
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``, or None.

        Attention layers always get the KV-cache method. Excluded layers and
        legacy vision-tower layers are left unquantized. ``LinearBase`` layers
        get ``UnquantizedLinearMethod``; any other layer type gets ``None``.
        Remaining linear / fused-MoE layers get the subclass-configured
        methods.
        """
        # Handle kv-cache first so we can focus only on weight quantization
        # thereafter.
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # Handle exclusion.
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions.
        if "vision_tower" in prefix or "vision_model" in prefix:
            # Only linear layers can use an unquantized *linear* method; for
            # any other layer type return None, exactly like the exclusion
            # branch above, instead of handing it a linear method it cannot
            # use.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # Now the layer is quantized; handle it here.
        if isinstance(layer, LinearBase):
            quant_method = self.LinearMethodCls(self)
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            quant_method = self.FusedMoEMethodCls(
                quant_config=self, moe_config=layer.moe_config
            )
            if getattr(quant_method, "backend", "") == "marlin":
                quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``exclude_modules`` from HF module names to vLLM names."""
        if len(self.exclude_modules) > 0:
            # This is a workaround for the weights remapping issue:
            # https://github.com/vllm-project/vllm/issues/28072
            # Right now, the Nvidia ModelOpt library use just one wildcard pattern:
            #        module_path*
            # It gets applied if the whole tree of modules rooted at module_path
            # is not quantized. Here we replace such pattern by 2 patterns that are
            # collectively equivalent to the original pattern:
            #        module_path
            #        module_path.*
            new_exclude_modules = []
            for exclude in self.exclude_modules:
                if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
                    new_exclude_modules.append(exclude[:-1])
                    new_exclude_modules.append(exclude[:-1] + ".*")
                else:
                    new_exclude_modules.append(exclude)

            self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        raise NotImplementedError("Please implement this function in sub classes")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        """Build a config from a ModelOpt or compressed-tensors style dict.

        Accepts either ``{"quantization": {"quant_algo": ...}}`` (legacy
        ``hf_quant_config.json``) or a flat ``{"quant_algo": ...,
        "quant_method": "modelopt"}`` section (``config.json``).

        Raises:
            ValueError: If ``quant_algo`` is missing or unsupported, or if
                ``kv_cache_quant_algo``, the exclude list, or ``group_size``
                have the wrong type.
        """
        # Handle both ModelOpt format and compressed-tensors style format.
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            group_size_raw = quant_config.get("group_size")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Normalize quant_algo for robust matching (ModelOpt may emit lowercase).
        quant_method = str(quant_method).upper()

        if kv_cache_quant_method is None:
            # No KV cache quantization, keep this branch just to have this comment
            pass
        elif not isinstance(kv_cache_quant_method, str):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )
        else:
            kv_cache_quant_method = kv_cache_quant_method.upper()

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            # Include the rejected value so the user can see what was parsed.
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                f"quantizations in vLLM, got quant_algo={quant_method!r}. "
                "Please check the `hf_quant_config.json` file for your "
                "model's quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8.

    Covers the ``FP8``, ``FP8_PER_CHANNEL_PER_TOKEN`` and ``FP8_PB_WO``
    ModelOpt algorithms. The linear method class is chosen from
    ``quant_method`` in ``__init__``.
    """

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        super().__init__(exclude_modules)
        self.quant_method = quant_method
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note "
                "that the format is experimental and could change.",
                quant_method,
            )

        # Select LinearMethod implementation based on quant_algo. This
        # instance attribute shadows the class-level ``LinearMethodCls``.
        if self.quant_method == "FP8":
            self.LinearMethodCls = ModelOptFp8LinearMethod
        elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN":
            self.LinearMethodCls = ModelOptFp8PcPtLinearMethod
        elif self.quant_method == "FP8_PB_WO":
            self.LinearMethodCls = ModelOptFp8PbWoLinearMethod
        else:
            raise ValueError(
                "Unsupported ModelOpt FP8 quant_algo for vLLM: "
                f"{self.quant_method}. Supported: FP8 / "
                "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO."
            )

    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config.

        Returns ``"modelopt"`` when ``quant_method`` is ``"modelopt"`` and the
        ``quant_algo`` (nested under ``"quantization"`` or at top level)
        mentions FP8; otherwise returns None.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'. ``or ""`` also covers an
        # explicit null value, which previously crashed on ``.lower()``.
        quant_method = str(hf_quant_cfg.get("quant_method") or "").lower()

        # Only proceed if the method is explicitly "modelopt".
        if quant_method != "modelopt":
            return None

        if "quantization" in hf_quant_cfg:
            # Look for ModelOpt-specific config structure.
            quant_config = hf_quant_cfg["quantization"]
            if not isinstance(quant_config, dict):
                return None
            quant_algo = quant_config.get("quant_algo", "")
        else:
            # Check for compressed-tensors style config with specific quant_algo.
            quant_algo = hf_quant_cfg.get("quant_algo", "")

        if "FP8" in str(quant_algo).upper():
            return "modelopt"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        # ``quant_method`` arrives upper-cased from ``from_config``.
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(
            quant_method,
            is_checkpoint_fp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )

class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Supports loading FP8 checkpoints with static weight scale and activation
    scale. Future support might be added for dynamic scales.

    Limitations:
    1. Only per-tensor quantization is supported (due to torch._scaled_mm
       support).
    2. Only the float8_e4m3fn datatype is supported.
    3. Only FP8-serialized checkpoints are supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Static per-tensor scales for both activations and weights.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8StaticTensorSym,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight plus per-partition weight/input scales.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized. In that case
                no ``weight_scale`` / ``input_scale`` would exist, yet
                ``process_weights_after_loading`` reads both; failing here
                gives a clear error instead of a later AttributeError.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "ModelOpt FP8 currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE: one scale per logical output partition; these are
        # collapsed into a single scale in process_weights_after_loading.
        weight_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE: also one per partition; reduced to the max after loading.
        scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-partition scales into single per-tensor scales."""
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        # If the partitions were stored with different scales, requantize
        # them against the largest scale so one per-tensor scale applies.
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths
            )

        # Weight is stored transposed to [in, out]. NOTE(review): presumably
        # the layout the fp8_linear kernel expects — confirm.
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel on ``x`` with the layer's weights."""
        return self.fp8_linear.apply_weights(layer, x, bias)
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt ``FP8_PER_CHANNEL_PER_TOKEN`` checkpoints.

    Per-linear checkpoint layout:
    - ``weight``: float8_e4m3fn, shape [out, in]
    - ``weight_scale``: float32, shape [out] (one scale per output channel)
    - no ``input_scale``; activations are quantized dynamically per token.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        # Dynamic per-token activation quantization; static weight scales.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8StaticTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and its per-output-channel scale."""
        del input_size, output_size

        # Only pre-quantized (FP8-serialized) checkpoints are handled here.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PER_CHANNEL_PER_TOKEN currently only supports "
                "FP8-serialized checkpoints."
            )

        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        fp8_weight_storage = torch.empty(
            out_features,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=fp8_weight_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty(out_features, dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        # Pre-fill with the most negative float32 until the checkpoint
        # value is loaded.
        channel_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", channel_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Transpose the weight to [in, out] and freeze both parameters."""
        transposed_weight = layer.weight.t()
        layer.weight = Parameter(transposed_weight, requires_grad=False)
        frozen_scale = layer.weight_scale.data
        layer.weight_scale = Parameter(frozen_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear kernel on ``x`` with the layer's weights."""
        return self.fp8_linear.apply_weights(layer, x, bias)


class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
    """Linear method for ModelOpt FP8_PB_WO checkpoints.

    ModelOpt exports `weight_scale` as a 4D tensor:
      [out_blk, 1, in_blk, 1]
    This method supports only a 128 x 128 block size
    (``_WEIGHT_BLOCK_SIZE``).

    vLLM executes it as a block-scaled FP8 GEMM with *dynamic* activation
    quantization in groups of 1 x block_k (per token, per block_k columns).
    """

    _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128)

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
        # Weights carry one scale per block_n x block_k tile; activations are
        # quantized dynamically in 1 x block_k groups.
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_k),
            cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
            use_aiter_and_is_supported=False,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and its 4D block-scale parameter.

        Raises:
            ValueError: If the checkpoint is not FP8-serialized, or if the
                input size or any output partition is not block-aligned.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "FP8_PB_WO currently only supports FP8-serialized checkpoints."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Expose block size so the v2 weight loaders can translate offsets from
        # element-space -> block-space for BlockQuantScaleParameter.
        layer.weight_block_size = self.weight_block_size

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        block_n, block_k = self._WEIGHT_BLOCK_SIZE
        if output_size_per_partition % block_n != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires out_features divisible by "
                f"{block_n}, got {output_size_per_partition}."
            )
        # The loaders translate each merged shard's (e.g. q/k/v) element
        # offset/size into block units, which is only exact when every shard
        # is a multiple of block_n. Checking only the sum would accept e.g.
        # [256, 64, 64], whose shard offsets no longer line up with the
        # out_blks scale rows allocated below.
        if any(part % block_n != 0 for part in output_partition_sizes):
            raise ValueError(
                "ModelOpt FP8_PB_WO requires every output partition to be "
                f"divisible by {block_n}, got {output_partition_sizes}."
            )
        if input_size_per_partition % block_k != 0:
            raise ValueError(
                "ModelOpt FP8_PB_WO requires in_features divisible by "
                f"{block_k}, got {input_size_per_partition}."
            )

        out_blks = output_size_per_partition // block_n
        in_blks = input_size_per_partition // block_k

        # Match ModelOpt's exported shape so weight loading works without a
        # custom loader: [out_blk, 1, in_blk, 1]
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32),
            input_dim=2,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Freeze the weight and reshape the block scale to 2D."""
        # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        scale = layer.weight_scale
        if scale.dim() == 4:
            # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk]
            scale = scale.squeeze(1).squeeze(-1)
        elif scale.dim() != 2:
            raise ValueError(
                "Unexpected ModelOpt FP8_PB_WO weight_scale shape: "
                f"{tuple(scale.shape)}."
            )

        layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the block-scaled FP8 GEMM; activation scales are dynamic."""
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            bias=bias,
        )
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static per-tensor weight scales and
    activation scales.

    Args:
        quant_config: The ModelOpt quantization config. Its checkpoint must be
            FP8-serialized (asserted in ``__init__``).
        moe_config: The fused-MoE layer configuration.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config
        assert self.quant_config.is_checkpoint_fp8_serialized

        # Select Fp8 MoE backend (static per-tensor weights and activations).
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=kFp8StaticTensorSym,
            activation_key=kFp8StaticTensorSym,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype required by the kernel's prepare/finalize step.

        Returns None until the kernel has been built.
        """
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Build the all2all prepare/finalize object for this backend, if any.

        Returns None for FlashInfer TRT-LLM, and for FlashInfer CUTLASS when
        all2all kernels are not in use. Other backends defer to the base class.
        """
        # TRT LLM not supported with all2all yet.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=False,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Create the experts implementation paired with ``prepare_finalize``.

        Requires ``self.moe_quant_config`` to have been set (done in
        ``_setup_kernel``).
        """
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None
        return make_fp8_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights and per-tensor scales on ``layer``.

        Allocated parameters:
          - ``w13_weight``: [num_experts, shards * intermediate, hidden], where
            shards is 2 for gated (act-and-mul) MoE and 1 otherwise.
          - ``w2_weight``: [num_experts, hidden, intermediate].
          - ``w13_weight_scale``: [num_experts, shards], initialized to 1.0.
          - ``w2_weight_scale``, ``w13_input_scale``, ``w2_input_scale``:
            [num_experts], initialized to 1.0.
        """
        layer.orig_dtype = params_dtype
        layer.num_experts = num_experts

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # WEIGHT SCALES - Per-tensor scaling for ModelOpts
        # For gated MoE, allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # For non-gated MoE, allocate 1 scale for w13.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, w13_num_shards),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts,), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def _setup_kernel(
        self,
        layer: torch.nn.Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
    ):
        """Convert weights to the backend's layout and build the kernel.

        Replaces the weight and weight-scale parameters on ``layer``, then sets
        ``self.moe_quant_config`` and, when one is produced, ``self.kernel``
        and ``self.use_inplace``.
        """
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w2_weight_scale", w2_scale)

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse loaded per-shard scales and prepare the runtime kernel."""
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # Per tensor kernels require single activation scale. Use the max.
        w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
            w13_input_scale, w2_input_scale
        )
        replace_parameter(layer, "w13_input_scale", w13_input_scale)
        replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        shard_size = layer.intermediate_size_per_partition
        w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
            w13,
            w13_scale,
            shard_size,
            num_experts=layer.w13_weight.shape[0],
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's current scales."""
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    @property
    def is_monolithic(self) -> bool:
        """True when routing and experts run in one call (FlashInfer TRT-LLM)."""
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing plus experts via FlashInfer TRT-LLM per-tensor FP8 MoE.

        Raises:
            NotImplementedError: if EPLB is enabled on ``layer``.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
            )
        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )
        assert not layer.renormalize
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular FP8 MoE kernel on pre-computed top-k routing."""
        assert not self.is_monolithic

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            assert layer.activation in ("silu", "relu2_no_mul"), (
                "Expected activation to be in ('silu', 'relu2_no_mul'), "
                f"but got {layer.activation}"
            )

        assert self.kernel is not None
        return self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

1012

1013
1014
1015
1016
1017
1018
# Attach the FP8 linear / fused-MoE / KV-cache method classes to the FP8
# config as class attributes. NOTE(review): presumably the shared ModelOpt
# config base reads these when choosing a quant method per layer — confirm.
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4 (NVFP4).

    Args:
        is_checkpoint_nvfp4_serialized: Whether the checkpoint stores NVFP4
            weights.
        kv_cache_quant_algo: KV-cache quantization algorithm name, if any.
        exclude_modules: Module patterns that are left unquantized.
        group_size: Number of weight elements sharing one block scale.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )

        # Assign unconditionally: previously these were only set for
        # serialized checkpoints, leaving the attributes missing otherwise
        # (AttributeError on any later access).
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config.

        Returns "modelopt_fp4" when ``quant_method`` is "modelopt" and either
        ``quantization.quant_algo`` contains "NVFP4" or a top-level
        ``quant_algo`` contains "FP4" (case-insensitive). Otherwise returns
        None. Non-string values are treated as "no match" instead of raising.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "")

        # Only proceed if the method is explicitly "modelopt"
        if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        """Build the config from parsed checkpoint fields.

        Raises:
            ValueError: if an NVFP4 checkpoint's "quantization" section lacks
                any of group_size, kv_cache_quant_algo, exclude_modules.
        """
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )
1117
1118
1119
1120
1121


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters (shapes
    as allocated by ``create_weights``):

    - ``weight``: NVFP4 values packed two per byte (``torch.uint8``),
      shape ``[out_features, in_features // 2]``.
    - ``weight_scale``: per-block scale (``torch.float8_e4m3fn``),
      shape ``[out_features, in_features // group_size]``.
    - ``weight_scale_2``: global weight scale (``torch.float32``), one value
      per logical output partition; reduced to a scalar (max) after loading.
    - ``input_scale``: global activation scale (``torch.float32``), one value
      per logical output partition; reduced to a scalar (max) after loading.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        self.marlin_input_dtype = None

        # Choose the NVFP4 GEMM backend. Without an override, prefer FlashInfer
        # CUTLASS, then CUTLASS, then Marlin. An explicit VLLM_NVFP4_GEMM_BACKEND
        # is asserted to be available; unrecognized values leave "none".
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif is_fp4_marlin_supported():
                self.backend = "marlin"
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin":
            self.backend = "marlin"
            assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        # Lazy %-formatting, consistent with other logger calls in this file.
        logger.info_once("Using %s for NVFP4 GEMM", self.backend)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed NVFP4 weights and their scales on ``layer``.

        Raises:
            ValueError: if the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )
        # Block scales are stored as FP8-E4M3 for serialized NVFP4 checkpoints
        # (the only case reachable here, see the check above).
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reduce global scales and lay out weights for the chosen backend."""
        # global scales: collapse per-partition values to one scalar (max).
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        # GEMM output rescale factor: input_scale * weight_scale_2.
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        # Calculate `1 / input_scale` so that we don't need to do so at runtime
        layer.input_scale_inv = Parameter(
            (1 / layer.input_scale).to(torch.float32), requires_grad=False
        )

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
            "Weight Block scale must be represented as FP8-E4M3"
        )

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale
        elif self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the selected NVFP4 backend.

        Marlin handles the full op. Other backends first quantize ``x`` to FP4
        with a swizzled block scale, then run the FP4 GEMM scaled by
        ``layer.alpha``. The output is returned as
        ``[x.shape[0], out_features]`` in ``x.dtype``.
        """
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
                input_dtype=self.marlin_input_dtype,
            )

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(
            x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
        )

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert x_blockscale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
1340
1341
1342
1343
1344


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    Fused-MoE method for ModelOpt NVFP4-serialized checkpoints.

    The experts backend is selected once at construction time. Weights are
    converted to that backend's layout in `process_weights_after_loading`,
    which is also where the modular kernel (if any) is built.

    Args:
        quant_config: NVFP4 quantization config.
        moe_config: Fused-MoE layer configuration (forwarded to the base class).
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe_config: FusedMoEConfig,
    ) -> None:
        super().__init__(moe_config)
        self.quant_config = quant_config

        # Select experts implementation.
        self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
            config=self.moe,
            weight_key=kNvfp4Static,
            activation_key=kNvfp4Dynamic,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

        # Whether the selected backend uses input scales for all (global)
        # experts rather than only the local ones; see create_weights.
        self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
            self.nvfp4_backend
        )

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype required by the built kernel's prepare/finalize
        stage, or None if the kernel has not been built yet."""
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """
        Return the prepare/finalize stage for the modular-kernel path.

        Returns None for FLASHINFER_TRTLLM. Also returns None for
        FLASHINFER_CUTLASS when all2all kernels are not in use. Otherwise
        FLASHINFER_CUTLASS uses the FlashInfer FP4 dispatcher. All other
        backends defer to the base class.
        """
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None
            # For now, fp4 moe only works with the flashinfer dispatcher.
            prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """
        Build the experts implementation that pairs with `prepare_finalize`.

        `self.moe_quant_config` must already be set (it is assigned in
        `process_weights_after_loading`). `layer` is not used here.
        """
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None
        return make_nvfp4_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        )

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Register the NVFP4 expert parameters on `layer`.

        Registered parameters:
          - `w13_weight`, `w2_weight`: packed FP4 weights (uint8, two FP4
            values per byte along the input dimension).
          - `w13_weight_scale`, `w2_weight_scale`: FP8 (e4m3) block scales,
            one per `quant_config.group_size` input elements.
          - `w13_weight_scale_2`, `w2_weight_scale_2`: fp32 per-tensor
            weight scales.
          - `w13_input_scale`, `w2_input_scale`: fp32 input scales.

        w13 has 2 shards (w1 and w3) when `self.moe.is_act_and_mul` is set,
        otherwise 1. Input scales are sized by `global_num_experts` (taken
        from `extra_weight_attrs`) when the backend supports a global scale
        factor, and by the local `num_experts` otherwise.
        """
        assert self.quant_config.is_checkpoint_nvfp4_serialized

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # NOTE(review): only consumed when self.use_global_sf is True; this
        # relies on the caller passing it in that case — confirm.
        global_num_experts = extra_weight_attrs.get("global_num_experts")
        w13_num_shards = 2 if self.moe.is_act_and_mul else 1

        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                w13_num_shards * intermediate_size_per_partition,
                # one FP8 block scale per `group_size` input elements
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # one FP8 block scale per `group_size` input elements
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=weight_scale_dtype,
            ),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # NOTE(review): `extra_weight_attrs` is not read again in this
        # method after this point, so this update (and the TENSOR one
        # below) appears to have no effect here — confirm before relying
        # on it.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Backends with global scale-factor support keep input scales for
        # every expert, not just the local ones.
        global_sf_num_experts = (
            global_num_experts if self.use_global_sf else num_experts
        )
        w13_input_scale = PerTensorScaleParameter(
            data=torch.empty(
                global_sf_num_experts,
                w13_num_shards,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(
            data=torch.empty(global_sf_num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Convert NVFP4 MoE weights into kernel format and setup the kernel.

        Steps:
          1. Collapse `w13_weight_scale_2` to the w1 column, warning if the
             w1 and w3 scales differ.
          2. Convert all weights and scales with
             `convert_to_nvfp4_moe_kernel_format` for the selected backend,
             and replace them on `layer`.
          3. Build `self.moe_quant_config`. When all2all kernels are unused,
             or the naive all2all path is used, also build `self.kernel`.
        """
        # Use a single gscale for w13.
        if self.moe.is_act_and_mul and not torch.allclose(
            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
        ):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected."
            )
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()

        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = convert_to_nvfp4_moe_kernel_format(
            nvfp4_backend=self.nvfp4_backend,
            layer=layer,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w13_scale_2=w13_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            w2_scale_2=layer.w2_weight_scale_2,
            a2_scale=layer.w2_input_scale,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w13_weight_scale", w13_scale)
        replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
        replace_parameter(layer, "w13_input_scale", a13_scale)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w2_weight_scale", w2_scale)
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config and (
            (not self.moe.moe_parallel_config.use_all2all_kernels)
            or self.moe.moe_parallel_config.use_naive_all2all_kernels
        ):
            assert self.experts_cls is not None
            self.kernel = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
            )

    @property
    def do_post_quant_allgather(self):
        """True only for the FlashInfer TRT-LLM backend, which is also the
        only backend `prepare_dp_allgather_tensor` accepts."""
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP.

        Quantizes `hidden_states` to FP4 with `flashinfer.fp4_quantize`,
        using `layer.a1_gscale` and a non-swizzled scale-factor layout.
        Returns the quantized tensor and a one-element list holding its
        scale factors. `router_logits` is not used.

        Raises:
            RuntimeError: if the backend is not FLASHINFER_TRTLLM.
        """
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        import flashinfer

        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the scales currently
        registered on `layer` for the selected backend."""
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w13_scale_2=layer.w13_weight_scale_2,
            w2_scale_2=layer.w2_weight_scale_2,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    @property
    def supports_eplb(self) -> bool:
        """This method supports expert-parallel load balancing (EPLB)."""
        return True

    @property
    def is_monolithic(self) -> bool:
        """True for the FlashInfer TRT-LLM backend without EPLB.

        Monolithic layers are served by `apply_monolithic`, which takes
        router logits. All other layers are served by `apply`, which takes
        precomputed top-k weights and ids.
        """
        return (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not self.moe.moe_parallel_config.enable_eplb
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing and experts in one FlashInfer TRT-LLM FP4 MoE call."""
        assert self.is_monolithic
        assert (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        )

        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=layer.top_k,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            custom_routing_function=layer.custom_routing_function,
            e_score_correction_bias=layer.e_score_correction_bias,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Run the experts on already-routed tokens.

        With FLASHINFER_TRTLLM this path is reached only with EPLB enabled,
        and it uses the routed TRT-LLM kernel. Every other backend uses the
        modular kernel built in `process_weights_after_loading`.
        """
        assert not self.is_monolithic

        # EPLB path
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
            )
        else:
            assert self.kernel is not None
            return self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )


# Attach the linear, fused-MoE and KV-cache method classes used by
# ModelOptNvFp4Config (the KV cache reuses the FP8 KV-cache method).
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod