fp8.py 52.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.distributed import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.batch_invariant import (
19
    vllm_is_batch_invariant,
20
)
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe import (
22
23
24
25
26
27
28
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
    FusedMoEQuantConfig,
31
    RoutingMethodType,
32
33
    fp8_w8a8_moe_quant_config,
)
34
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
35
36
37
38
39
40
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
41
from vllm.model_executor.layers.quantization import QuantizationMethods
42
from vllm.model_executor.layers.quantization.base_config import (
43
44
45
    QuantizationConfig,
    QuantizeMethodBase,
)
46
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
47
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
48
49
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
50
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
51
52
53
54
55
56
57
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
58
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
59
60
61
62
63
64
65
66
67
68
69
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace,
    validate_fp8_block_shape,
)
70
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
71
72
73
74
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
75
from vllm.model_executor.layers.quantization.utils.quant_utils import (
76
77
78
    GroupShape,
    is_layer_skipped,
)
79
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
80
81
82
83
84
85
86
87
88
89
90
91
92
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
93
from vllm.model_executor.utils import set_weight_attrs
94
from vllm.platforms import current_platform
95
from vllm.scalar_type import scalar_types
96
97
98
99
100
from vllm.utils.deep_gemm import (
    get_col_major_tma_aligned_tensor,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
101
from vllm.utils.flashinfer import has_flashinfer_moe
102
from vllm.utils.import_utils import has_deep_gemm
103

104
105
106
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

107
108
109
110
logger = init_logger(__name__)

# Activation quantization schemes accepted by Fp8Config. With "static" the
# linear method registers an `input_scale` parameter to be loaded from the
# checkpoint; with "dynamic" it registers `input_scale` as None.
ACTIVATION_SCHEMES = ["static", "dynamic"]

111

112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class Fp8MoeBackend(Enum):
    """Kernel families that can execute an FP8 fused-MoE layer.

    Members carry explicit integer values. ``get_fp8_moe_backend`` selects
    one of them; it never returns ``NONE``.
    """

    NONE = 0  # sentinel, not produced by get_fp8_moe_backend
    FLASHINFER_TRTLLM = 1  # FlashInfer, TensorRT-LLM kernels
    FLASHINFER_CUTLASS = 2  # FlashInfer, CUTLASS kernels
    DEEPGEMM = 3  # DeepGEMM (block-quantized weights only)
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4  # CUTLASS block-scaled grouped GEMM
    MARLIN = 5  # Marlin weight-only FP8
    TRITON = 6  # Triton, the default fallback


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """Select the primary FP8 MoE backend for the current platform.

    Candidates are checked in priority order and the first that applies wins:

    1. FlashInfer on CUDA SM100, when ``VLLM_USE_FLASHINFER_MOE_FP8`` is set
       and FlashInfer MoE is available. The TRT-LLM flavour is used when
       FlashInfer reports it; otherwise the CUTLASS flavour is used.
    2. Marlin (weight-only FP8) when the device lacks capability 8.9 or
       ``VLLM_TEST_FORCE_FP8_MARLIN`` is set. Marlin is never used on ROCm.
    3. DeepGEMM for block-quantized weights, when both DeepGEMM env flags are
       set and DeepGEMM is installed and supported.
    4. CUTLASS block-scaled grouped GEMM for block-quantized weights on CUDA
       SM100.
    5. Triton otherwise.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the MoE weights use block-wise quantization.

    Returns:
        The selected ``Fp8MoeBackend`` member.

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) flavour would be
            selected for block-quantized weights.
    """
    use_flashinfer = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if use_flashinfer:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # Any non-TRT-LLM FlashInfer backend maps to the CUTLASS kernels,
        # which do not accept block-quantized weights.
        if block_quant:
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only Marlin path for GPUs without native FP8 (or when forced by
    # the test flag). ROCm never takes this path.
    lacks_native_fp8 = not current_platform.has_device_capability(89)
    wants_marlin = lacks_native_fp8 or envs.VLLM_TEST_FORCE_FP8_MARLIN
    on_rocm = current_platform.is_rocm()
    if wants_marlin and not on_rocm:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM only handles block-quantized weights.
    deep_gemm_requested = (
        envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant
    )
    if deep_gemm_requested:
        if not has_deep_gemm():
            logger.warning_once("DeepGEMM backend requested but not available.")
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights.
    sm100_block_quant = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and block_quant
    )
    if sm100_block_quant:
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


182
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights and scales. If False, weights are loaded in their
            original dtype and quantized to FP8 after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Names of layers to leave unquantized, matched via
            ``is_layer_skipped``. ``None`` is normalized to an empty list.
        weight_block_size: Optional 2-element block shape for block-wise
            weight quantization. Requires an FP8-serialized checkpoint and
            the "dynamic" activation scheme.

    Raises:
        ValueError: On an unknown activation scheme, or an invalid
            ``weight_block_size`` combination.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts (bf16 and fp16)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (8.0).

        On devices below 8.9 the linear method falls back to Marlin
        weight-only FP8.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are required beyond the model config."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate HF layer names in ``ignored_layers`` to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint's quantization config.

        The checkpoint is treated as FP8-serialized when ``quant_method``
        contains "fp8". ``ignored_layers`` falls back to
        ``modules_to_not_convert`` when it is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Return True if the layer at ``prefix`` must stay unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU quantize method for ``layer``, or None.

        Linear and MoE layers listed in ``ignored_layers`` get the
        unquantized methods, matching ``get_quant_method``. The other layers
        get the IPEX-based XPU FP8 methods, built from a fresh copy of this
        config.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            # Respect ignored_layers for MoE too; previously XPU quantized
            # skipped MoE layers, unlike the non-XPU path below.
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unsupported.

        XPU is delegated to ``get_xpu_quant_method``. Layers matched by
        ``ignored_layers`` get the corresponding unquantized method.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None if
            the name does not match any known scale pattern
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

334
335
336

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale, as well as block-wise quantized FP8
    checkpoints (``weight_block_size``) with dynamic activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling; such weights are quantized to FP8 in
    ``process_weights_after_loading``.

    On devices without compute capability 8.9, or when
    ``VLLM_TEST_FORCE_FP8_MARLIN`` is set, the weight-only Marlin kernel is
    used instead. This is never done on ROCm or in batch-invariant mode.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode has its own apply path; never use Marlin there.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        # NOTE: "enaled" (sic) matches the spelling of the rocm_aiter_ops
        # helper being called; do not "fix" it here.
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.weight_block_size:
            # Activations are grouped along K to match the weight block width.
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        elif not self.act_q_static and cutlass_fp8_supported():
            # Use per-token quantization for better perf if dynamic and cutlass
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer`` for loading.

        FP8-serialized checkpoints get an FP8 ``weight`` and a scale
        parameter. The scale is ``weight_scale`` (per-tensor) or
        ``weight_scale_inv`` (block-wise). An ``input_scale`` parameter is
        also registered; it is ``None`` for dynamic activations. Other
        checkpoints get only a ``params_dtype`` weight, which is quantized in
        ``process_weights_after_loading``.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the chosen kernel expects.

        Non-block weights end up transposed (``.t()``) relative to how they
        were created. Block weights go through
        ``process_fp8_weight_block_strategy``, and ``weight_scale_inv`` is
        replaced by ``weight_scale``. For Marlin, the layer is prepared with
        ``prepare_fp8_layer_for_marlin`` and ``input_scale`` is removed.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the layer's FP8 weight.

        Dispatch order:
        1. Batch-invariant mode:
           - block-quantized weights use the block FP8 op;
           - other weights are dequantized and run through a plain GEMM.
        2. Marlin.
        3. Block FP8 op.
        4. Per-tensor/per-token FP8 op.
        """
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # per-tensor/channel: dequantize and run a plain GEMM
            return self._apply_dequantized(layer, x, bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

    def _apply_dequantized(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Dequantize a non-block FP8 weight and run a plain GEMM.

        Used in batch-invariant mode. ``process_weights_after_loading``
        stores non-block weights transposed as [K, N] (input dim first), so
        a per-output-channel scale indexes the last dimension.
        """
        # Dequantize in the activation dtype. F.linear requires both operands
        # in the same dtype, and fp16 activations are supported alongside
        # bf16. For bf16 inputs this equals the previous dequant-to-bf16.
        weight = layer.weight.to(x.dtype)
        weight_scale = layer.weight_scale.to(x.dtype)
        num_out = weight.shape[1]

        if weight_scale.numel() == 1:
            # Per-tensor: simple scalar multiplication.
            dequant = weight * weight_scale
        elif weight_scale.numel() == num_out:
            # One scale per output channel (e.g. [N] or [N, 1]). Broadcast
            # along N, the last dim of [K, N]. This is checked before the
            # per-row case so square weights (K == N) use the right axis.
            dequant = weight * weight_scale.reshape(1, -1)
        elif weight_scale.dim() == 1 and weight_scale.shape[0] == weight.shape[0]:
            # Per-row (K) scaling, kept for backward compatibility.
            dequant = weight * weight_scale.unsqueeze(1)
        else:
            # Unknown layout: rely on plain broadcasting.
            dequant = weight * weight_scale
        return torch.nn.functional.linear(x, dequant.t(), bias)
617
618


619
620
621
622
623
624
625
626
627
628
629
630
631
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(self.block_quant)

        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        # Map the selected FP8 MoE backend onto the FlashInfer flavour (if any).
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weight, weight-scale and input-scale parameters.

        Weights are allocated in float8_e4m3fn when the checkpoint is FP8
        serialized, otherwise in ``params_dtype`` (quantized later in
        ``process_weights_after_loading``).

        Raises:
            ValueError: if block quantization is used and the per-partition
                intermediate size is not divisible by the block size, or if a
                static activation scheme is requested for a non-FP8 checkpoint.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding partial tiles up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the form the chosen kernel needs.

        Three cases:
          * block-quantized checkpoints: optional fnuz normalization or
            FlashInfer w13->w31 swap, plus DeepGEMM scale alignment;
          * unquantized (fp16/bf16) checkpoints: quantize each expert to FP8
            in place with one scale per expert and weight;
          * per-tensor FP8 checkpoints: collapse input scales to their max and
            requantize both w13 shards to a shared per-expert scale.
        Marlin and DeepGEMM-UE8M0 post-processing run afterwards if enabled.
        """
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w13_weight_scale_inv
                    )
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w2_weight_scale_inv
                    )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # The local `w2_weight` is only bound in this branch on
                    # fp8-fnuz platforms (never the case for FlashInfer), so
                    # referencing it raised UnboundLocalError. Pass the layer's
                    # Parameter itself: the helper's return value is unused, so
                    # it presumably updates its arguments in place, and an
                    # in-place `.data` reassignment must reach the layer.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )

    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for all2all, or None.

        Returns None for backends that do not use a modular prepare/finalize
        (ROCm AITER, Marlin, FlashInfer TRT-LLM).
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the expert GEMM implementation matching ``prepare_finalize``."""
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config; None when Marlin is used."""
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the FP8 fused-MoE computation.

        FlashInfer TRT-LLM kernels consume ``router_logits`` directly and
        return immediately. Every other backend routes through
        ``FusedMoE.select_experts`` first.

        Returns:
            The MoE output, or ``(output, zero_expert_result)`` when the layer
            has zero experts configured.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert not renormalize and custom_routing_function is not None
                # Return directly, like the block-quant branch above. Falling
                # through would re-route the tokens and run the generic
                # fused_experts path on FlashInfer-swapped weights, overwriting
                # this result.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )

        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert not renormalize and custom_routing_function is not None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )

            result = flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )
        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1307
1308


1309
1310
1311
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1312
1313
1314
    """

    def __init__(self, quant_config: Fp8Config):
1315
        super().__init__(quant_config)