# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.attention.layer import Attention
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.batch_invariant import (
20
    vllm_is_batch_invariant,
21
)
from vllm.model_executor.layers.fused_moe import (
23
24
25
26
27
28
29
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
30
from vllm.model_executor.layers.fused_moe.config import (
31
    FusedMoEParallelConfig,
32
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
35
    fp8_w8a8_moe_quant_config,
)
36
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
37
38
39
40
41
42
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
from vllm.model_executor.layers.quantization.base_config import (
45
46
47
    QuantizationConfig,
    QuantizeMethodBase,
)
48
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
49
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
50
51
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
52
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
53
54
55
56
57
58
59
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
60
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
61
62
63
64
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
65
    deepgemm_post_process_fp8_weight_block,
66
67
68
69
70
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
71
72
73
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
74
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
75
76
77
78
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
79
from vllm.model_executor.layers.quantization.utils.quant_utils import (
80
81
82
    GroupShape,
    is_layer_skipped,
)
83
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
84
85
86
87
88
89
90
91
92
93
94
95
96
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
97
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
98
from vllm.platforms import current_platform
99
from vllm.scalar_type import scalar_types
100
101
102
103
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
104
from vllm.utils.flashinfer import has_flashinfer_moe
105
from vllm.utils.import_utils import has_deep_gemm
106

107
108
109
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

110
111
112
113
logger = init_logger(__name__)

# Activation quantization schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]

114

115
116
117
118
119
120
121
122
123
124
class Fp8MoeBackend(Enum):
    """Kernel families that can execute an FP8 fused-MoE layer.

    Members carry fixed integer values; ``get_fp8_moe_backend`` returns one
    of them.
    """

    NONE = 0
    # FlashInfer TensorRT-LLM kernels.
    FLASHINFER_TRTLLM = 1
    # FlashInfer CUTLASS kernels.
    FLASHINFER_CUTLASS = 2
    # DeepGEMM grouped GEMM (block-quantized weights).
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM (block-quantized weights).
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Weight-only Marlin kernel.
    MARLIN = 5
    # Triton fused-MoE kernels; the default fallback.
    TRITON = 6


125
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend:
    """
    Choose the primary FP8 MoE backend from platform capabilities and env vars.

    Candidates are considered in this order; the first one that applies wins:
    Triton (whenever LoRA support is requested), FlashInfer TRTLLM or
    FlashInfer CUTLASS, Marlin, DeepGEMM, CUTLASS block-scaled grouped GEMM,
    and finally Triton as the default.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: whether the MoE weights use block-wise quantization.
        moe_parallel_config: parallel layout; only ``tp_size`` is consulted.
        with_lora_support: when True, Triton is returned unconditionally.

    Raises:
        ValueError: if the FlashInfer CUTLASS (throughput) backend would be
            used with block-quantized weights on an SM100-family device.
    """
    # LoRA forces Triton before any platform probing happens.
    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # FlashInfer backends: CUDA SM90 or SM100-family, opted in via env var.
    flashinfer_ok = (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_ok:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # Throughput (CUTLASS) backend: no block-quant support on SM100.
        if block_quant and current_platform.is_device_capability_family(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only Marlin for GPUs lacking capability 89 (no native FP8) or
    # when forced for testing; never selected on ROCm.
    wants_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if not current_platform.is_rocm() and wants_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM for block-quantized weights. An explicit VLLM_MOE_USE_DEEP_GEMM
    # is always honored; when left unset, DeepGEMM is off at TP size >= 8.
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    off_by_tp_default = (
        not envs.is_set("VLLM_MOE_USE_DEEP_GEMM")
        and moe_parallel_config.tp_size >= 8
    )
    if off_by_tp_default:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
        if has_deep_gemm():
            # Installed but unsupported on this device: fall through silently.
            if is_deep_gemm_supported():
                logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
                return Fp8MoeBackend.DEEPGEMM
        else:
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )

    # CUTLASS block-scaled grouped GEMM on SM100-family with block-quant weights.
    on_sm100_family = (
        current_platform.is_cuda()
        and current_platform.is_device_capability_family(100)
    )
    if on_sm100_family and block_quant:
        logger.info_once(
            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE", scope="local"
        )
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # Nothing more specific applied.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


209
class Fp8Config(QuantizationConfig):
    """Quantization config for FP8 weights (and optionally activations).

    Args:
        is_checkpoint_fp8_serialized: whether the checkpoint already stores
            FP8 weights. When False, the online method variants are used.
        activation_scheme: one of ``ACTIVATION_SCHEMES``.
        ignored_layers: layer prefixes to leave unquantized.
        weight_block_size: optional 2-element block shape for block-wise
            weight scales. Only allowed for FP8-serialized checkpoints with
            the "dynamic" activation scheme.

    Raises:
        ValueError: for an unknown activation scheme or an unsupported
            block-wise configuration.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        # Block-wise quantization has extra requirements, checked in order.
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            num_dims = len(weight_block_size)
            if num_dims != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {num_dims} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Name this quantization method is registered under."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes this config can run with."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability required by this config."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are read for this method."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` through the HF-to-vLLM name mapper."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint's quantization config dict.

        ``quant_method`` and ``activation_scheme`` are required keys. The
        ignored-layer list is read from ``ignored_layers`` and, when that is
        missing or empty, from ``modules_to_not_convert``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skip_list = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not skip_list:
            # Alternative key for the exclusion list.
            skip_list = cls.get_from_keys_or(config, ["modules_to_not_convert"], None)
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=scheme,
            ignored_layers=skip_list,
            weight_block_size=block_size,
        )

    def _is_skipped(self, prefix: str) -> bool:
        """Return whether the layer at ``prefix`` is excluded from quantization."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer`` on XPU devices."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh copy of this config.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer``, or None if unhandled."""
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, LinearBase):
            if self._is_skipped(prefix):
                return UnquantizedLinearMethod()
            linear_method = Fp8LinearMethod(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if self._is_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            # Serialized FP8 checkpoints load FP8 directly; otherwise the
            # online method variant is used.
            moe_method_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            moe_method = moe_method_cls(self, layer)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a compressed-tensors k/v/q cache-scale param name to its vLLM name.

        :param name: param name
        :return: matching param name for the attention scale in vLLM, or
            None when ``name`` does not follow one of the known patterns
        """
        if name.endswith(".output_scale"):
            # Checked in this order; the first projection found in the name wins.
            for proj, old_suffix, new_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if proj in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # No known pattern matched.
        return None

365
366
367

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale. Block-wise weight scales
    (``weight_block_size``) are supported with dynamic activation scaling.

    Also supports loading FP16/BF16 model checkpoints and quantizing the
    weights to FP8 after loading, with dynamic activation scaling. The weight
    scaling factor is computed once all model weights are loaded.

    Limitations:
    1. Non-block FP8 checkpoints are collapsed to a single per-tensor weight
       scale (fused shards are requantized) because torch._scaled_mm needs it.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Marlin is not used in batch-invariant mode.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # The weight is [N, K] and its block shape is
            # GroupShape(*weight_block_size) = (block_n, block_k). Activations
            # are grouped per token along K, so their group width must be
            # block_k (index 1). Index 0 only worked for square blocks.
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[1])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        ``weight`` is created with shape
        (output_size_per_partition, input_size_per_partition). For
        FP8-serialized checkpoints this also registers ``weight_scale``
        (per-tensor) or ``weight_scale_inv`` (block-wise), plus
        ``input_scale`` (a parameter when activations are static, else None).
        For other checkpoints ``weight`` uses ``params_dtype`` and its loader
        calls ``process_weights_after_loading`` once every element is loaded.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # load the current weight chunk
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0
                layer._loaded_numel += loaded_weight.numel()

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the selected kernel expects.

        Block-quant weights go through ``process_fp8_weight_block_strategy``.
        Other weights are quantized (non-serialized checkpoints) or requantized
        to one scale (serialized, non-Marlin), then transposed to [K, N].
        Finally the layer is prepared for Marlin, or block post-processing is
        applied.
        """
        # patched_weight_loader may already have run this for the layer.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        else:
            # If checkpoint not serialized fp8, quantize the weights.
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (optionally adding ``bias``)."""
        # In batch-invariant mode, block-quantized weights use the block FP8
        # op; other layouts are dequantized and run through a regular GEMM.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # Per-tensor/channel: dequantize in BF16. The weight is [K, N]
            # here (transposed in process_weights_after_loading).
            weight_bf16 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication.
                weight_bf16 = weight_bf16 * weight_scale
            elif weight_scale.numel() == weight_bf16.shape[1]:
                # One scale per output channel: broadcast across the N
                # columns. Reshaping also covers [N, 1] scales and avoids
                # scaling the K rows when K == N.
                weight_bf16 = weight_bf16 * weight_scale.reshape(1, -1)
            else:
                # Fallback: rely on default broadcasting.
                weight_bf16 = weight_bf16 * weight_scale
            # F.linear requires both operands in one dtype; casting to the
            # activation dtype keeps fp16 models working (no-op for bf16).
            weight_dq = weight_bf16.to(x.dtype)
            return torch.nn.functional.linear(x, weight_dq.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
688
689


690
691
692
693
694
695
696
697
698
699
700
701
702
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints whose expert weights carry either
    per-tensor scales (two per expert for the fused w13, one per expert for
    w2; the two w13 scales are merged after loading) or block-wise scales
    (``weight_block_size``). Activation scales may be dynamic or static;
    static activation scales are only accepted together with per-tensor
    weight scales.

    Online quantization of FP16/BF16 checkpoints is handled by the
    ``Fp8OnlineMoEMethod`` subclass; ``create_weights`` here requires an
    FP8-serialized checkpoint.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to. ``layer.moe_config``
            is passed to the base class and ``layer.moe_parallel_config`` is
            used to pick the FP8 MoE kernel backend.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        # Resolved for real in process_weights_after_loading; initialize it
        # here so the attribute always exists even before weights are created.
        self.rocm_aiter_moe_enabled = False
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Only defined (and only used by `apply`) for the CUTLASS backend.
            self.flashinfer_moe_fn = partial(
                flashinfer_cutlass_moe_fp8,
                moe=self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales on ``layer``.

        ``w13_weight`` (fused gate/up) has shape
        ``(num_experts, 2 * intermediate_size_per_partition, hidden_size)`` and
        ``w2_weight`` (down) has shape
        ``(num_experts, hidden_size, intermediate_size_per_partition)``. Both
        are allocated as ``torch.float8_e4m3fn``; the incoming ``params_dtype``
        is only recorded as ``layer.orig_dtype``.

        Per-tensor weight scales are registered as ``w13_weight_scale`` /
        ``w2_weight_scale``; block-wise scales as ``w13_weight_scale_inv`` /
        ``w2_weight_scale_inv`` with one entry per (block_n, block_k) tile.
        Input scales are only registered for the static activation scheme;
        otherwise they are set to ``None``.

        Raises:
            ValueError: for block quantization, if
                ``intermediate_size_per_partition`` is not divisible by
                ``block_n``, or (when tensor parallel size > 1) by ``block_k``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding partial tiles up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded checkpoint tensors into the selected kernel's layout.

        Block-quantized weights: on fnuz platforms they are normalized from
        e4m3fn to e4m3fnuz; for FlashInfer the w1/w3 halves of w13 are
        swapped; the parameters are re-registered, then optionally shuffled
        for ROCm AITER and re-laid-out for DeepGEMM.

        Per-tensor weights: static input scales are collapsed to their max;
        on fnuz platforms weights/scales are normalized; each expert's two
        w13 halves are requantized to a single shared (max) scale; weights
        are optionally shuffled for ROCm AITER or prepared for FlashInfer.

        With Marlin, the layer is repacked and the input scales are removed.

        Returns early if ``layer`` carries a truthy
        ``_already_called_process_weights_after_loading`` attribute (that flag
        is not set in this class).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            replace_parameter(layer, "w13_weight", w13_weight)
            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
            replace_parameter(layer, "w2_weight", w2_weight)
            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                replace_parameter(layer, "w13_weight", w13_weight)
                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
                if w13_input_scale is not None:
                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
                replace_parameter(layer, "w2_weight", w2_weight)
                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
                if w2_input_scale is not None:
                    replace_parameter(layer, "w2_input_scale", w2_input_scale)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            replace_parameter(layer, "w13_weight_scale", max_w13_scales)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # The helper's return value is unused: it updates its
                    # arguments in place. Pass the registered w2 parameter;
                    # a local `w2_weight` is only bound on the fnuz path
                    # above, which FlashInfer never takes.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the modular MoE kernel.

        Returns ``None`` for backends that do not use one (ROCm AITER, Marlin,
        FlashInfer TRT-LLM), a FlashInfer CUTLASS-specific object for the
        CUTLASS backend, and otherwise defers to the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts GEMM implementation matching ``prepare_finalize``.

        Batched activation formats get batched DeepGEMM/Triton experts; LoRA
        gets plain Triton experts; FlashInfer CUTLASS gets its own experts;
        everything else gets ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config from the layer's scales.

        Returns ``None`` for Marlin, whose input scales are removed in
        ``process_weights_after_loading``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the fused MoE forward pass for ``x`` on the selected backend.

        The FlashInfer TRT-LLM backend performs routing inside the kernel and
        returns its output directly. All other backends route with
        ``layer.select_experts`` first.

        Returns:
            The expert output, or ``(output, zero_expert_result)`` when the
            layer has zero experts configured.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                # Return directly: falling through would discard this result
                # and rerun the generic path on FlashInfer-rearranged weights.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        select_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert layer.activation == "silu", (
                f"{layer.activation} not supported for Marlin MoE."
            )
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
                workspace=layer.workspace,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )
            if not self.block_quant:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                assert layer.scoring_func == "sigmoid", (
                    f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
                )
            # Delegate to CUTLASS FlashInfer path; function already bound with
            # use_deepseek_fp8_block_scale for block-quant when applicable
            result = self.flashinfer_moe_fn(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )

        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1308
1309


1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads unquantized FP16/BF16 MoE checkpoints and quantizes the expert
    weights to FP8 after they are loaded, using one weight scale per expert
    for each of ``w13_weight`` and ``w2_weight``. Activations are scaled
    dynamically (no static input scales are created).

    Quantization is done in a streaming fashion: the weight loader is wrapped
    so that ``process_weights_after_loading`` runs as soon as every element of
    ``w13_weight`` and ``w2_weight`` has been written, and the regular
    post-loading call then becomes a no-op for that layer.

    Args:
        quant_config: The quantization config. Must describe a non-serialized
            (online), dynamic-activation, non-block-scaled FP8 scheme.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate high-precision expert weights and per-expert FP8 scales.

        The ``weight_loader`` passed in ``extra_weight_attrs`` is wrapped so
        that ``process_weights_after_loading`` is triggered as soon as all
        elements of ``w13_weight`` and ``w2_weight`` have been written.
        """
        # Local import: private torch module, only needed while loading.
        from torch.utils._python_dispatch import TorchDispatchMode

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        class _ParamCopyNumelCounter(TorchDispatchMode):
            """Count elements written via ``aten.copy_`` into one tensor's storage."""

            def __init__(self, target: torch.Tensor):
                super().__init__()
                self.target_storage_ptr = target.untyped_storage().data_ptr()
                self.copied_numel = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                out = func(*args, **(kwargs or {}))
                # Only count copies whose destination lives in the target
                # parameter's storage (views such as per-expert slices share it).
                if (
                    func == torch.ops.aten.copy_.default
                    and args[0].untyped_storage().data_ptr()
                    == self.target_storage_ptr
                ):
                    self.copied_numel += args[0].numel()
                return out

        # We are doing online quantization, so patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Load the current weight chunk, counting the elements actually
            # written into `param`. `loaded_weight.numel()` is not used because
            # the wrapped loader may narrow `loaded_weight` (e.g. to this
            # rank's shard) or skip it (e.g. for experts this rank does not
            # own), so the checkpoint tensor size need not match what lands
            # in `param`.
            counter = _ParamCopyNumelCounter(param)
            with counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

            # Track how many elements of w13/w2 have been filled so far.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0
            layer._loaded_numel += counter.copied_numel

            # Once every element has been loaded, quantize right away.
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping.
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything.
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a new dict rather than mutating the one we were given, so
        # nothing else holding a reference to it sees the patched loader.
        extra_weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS (kept in the checkpoint dtype until quantized)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for each of w13 and w2. They are initialized to
        # ones here and filled in by `process_weights_after_loading`.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Dynamic activation scaling: no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision expert weights to FP8.

        Does nothing if the streaming weight loader has already run this for
        ``layer``. Otherwise replaces ``w13_weight``/``w2_weight`` with FP8
        tensors, fills the per-expert scales, and applies the AITER or
        Marlin weight layout when those backends are in use.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Lazy import to avoid importing triton too early.
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # The checkpoint is FP16/BF16: quantize each expert separately,
        # storing one scale per expert.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for Marlin if needed.
        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )


1455
1456
1457
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1458
1459
1460
    """

    def __init__(self, quant_config: Fp8Config):
1461
        super().__init__(quant_config)