fp8.py 54.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
25
26
27
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
30
31
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
32
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
33
34
35
36
37
38
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
39
from vllm.model_executor.layers.quantization import QuantizationMethods
40
from vllm.model_executor.layers.quantization.base_config import (
41
42
43
    QuantizationConfig,
    QuantizeMethodBase,
)
44
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
45
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
46
47
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
48
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
49
50
51
52
53
54
55
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
56
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
57
58
59
60
61
62
63
64
65
66
67
68
    W8A8BlockFp8LinearOp,
    check_aiter_fp8_linear_support,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace,
    validate_fp8_block_shape,
)
69
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
70
71
72
73
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
74
from vllm.model_executor.layers.quantization.utils.quant_utils import (
75
76
77
    GroupShape,
    is_layer_skipped,
)
78
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
79
80
81
82
83
84
85
86
87
88
89
90
91
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
92
from vllm.model_executor.utils import set_weight_attrs
93
from vllm.platforms import current_platform
94
from vllm.scalar_type import scalar_types
95
from vllm.utils import has_deep_gemm
96
97
98
99
100
from vllm.utils.deep_gemm import (
    get_col_major_tma_aligned_tensor,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
101
from vllm.utils.flashinfer import has_flashinfer_moe
102

103
104
105
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

106
107
108
109
logger = init_logger(__name__)

# Activation scaling schemes that Fp8Config accepts.
ACTIVATION_SCHEMES = ["static", "dynamic"]

110

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
class Fp8MoeBackend(Enum):
    """Kernel families that can execute an FP8 fused-MoE layer.

    The integer values are explicit and stable; ``get_fp8_moe_backend``
    decides which member is used for a given platform.
    """

    NONE = 0

    # FlashInfer kernels (chosen only on CUDA SM100 by get_fp8_moe_backend).
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2

    # Kernels chosen only for block-quantized weights.
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4

    # Weight-only path for GPUs without native FP8.
    MARLIN = 5

    # Default when nothing more specific applies.
    TRITON = 6


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """
    Choose the primary FP8 MoE backend for the current platform.

    Candidates are tried in priority order and the first match wins:

    1. FlashInfer (TRTLLM or CUTLASS flavour) on CUDA SM100 when
       ``VLLM_USE_FLASHINFER_MOE_FP8`` is set and FlashInfer MoE is installed.
    2. Marlin (weight-only) when the GPU lacks compute capability 8.9 or
       ``VLLM_TEST_FORCE_FP8_MARLIN`` is set; never on ROCm.
    3. DeepGEMM for block-quantized weights when ``VLLM_USE_DEEP_GEMM`` is
       set, DeepGEMM is installed and the platform supports it.
    4. CUTLASS block-scaled grouped GEMM for block-quantized weights on
       CUDA SM100.
    5. Triton otherwise.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the MoE weights are block-quantized.

    Returns:
        The selected ``Fp8MoeBackend`` member.
    """
    flashinfer_usable = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_usable:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only Marlin for GPUs without native FP8 (or when forced);
    # ROCm never takes this path.
    wants_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    on_rocm = current_platform.is_rocm()
    if wants_marlin and not on_rocm:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM is opt-in and only handles block-quantized weights.
    if envs.VLLM_USE_DEEP_GEMM and block_quant:
        if has_deep_gemm():
            if is_deep_gemm_supported():
                logger.info_once("Using DeepGEMM backend for FP8 MoE")
                return Fp8MoeBackend.DEEPGEMM
        else:
            logger.warning_once("DeepGEMM backend requested but not available.")

    on_cuda_sm100 = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
    )
    if on_cuda_sm100 and block_quant:
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


174
class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """Validate and store the FP8 quantization options.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint already
                stores FP8 weights (as opposed to FP16/BF16 weights that are
                quantized after loading).
            activation_scheme: Activation scaling scheme; must be one of
                ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer prefixes that must stay unquantized.
                ``None`` is stored as an empty list.
            weight_block_size: ``[block_n, block_k]`` for block-wise weight
                quantization, or ``None`` for non-block quantization.

        Raises:
            ValueError: If ``activation_scheme`` is unsupported, or if
                ``weight_block_size`` is given for a non-FP8 checkpoint,
                does not have exactly two entries, or is combined with a
                non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``ignored_layers`` from HF naming to vLLM module names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization-config dict.

        The checkpoint counts as FP8-serialized when ``quant_method``
        contains ``"fp8"``. ``modules_to_not_convert`` is used as a fallback
        when ``ignored_layers`` is absent or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Return whether the layer at ``prefix`` must stay unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """XPU counterpart of :meth:`get_quant_method`.

        Ignored Linear and FusedMoE layers get the unquantized method, exactly
        as on other platforms. The remaining ones get the IPEX-based XPU FP8 methods.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            # Previously ignored MoE layers were still quantized on XPU,
            # unlike in get_quant_method.
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``, or ``None``.

        XPU is delegated to :meth:`get_xpu_quant_method`. Layers listed in
        ``ignored_layers`` get the unquantized Linear/MoE methods.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

326
327
328

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Block-wise FP8 checkpoints (``weight_block_size`` set) are supported with
    dynamic activation scaling only.

    Limitations:
    1. Without block quantization, weights use per-tensor scales due to
       torch._scaled_mm support. On the non-Marlin path, fused modules are
       requantized to a single scale.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode dequantizes the weights in apply(), which needs
        # the non-Marlin weight layout.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # weight_block_size is [block_n, block_k]. Activations are
            # quantized in groups along K, so the group width must be block_k.
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[1])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, the scale parameters.

        For FP8-serialized checkpoints this registers ``weight`` (via
        ``create_fp8_weight_parameter``), then either ``weight_scale``
        (per-tensor) or ``weight_scale_inv`` (block-wise), then
        ``input_scale``: a parameter for static activations, ``None`` for
        dynamic ones. For other checkpoints only ``weight`` is registered, in
        ``params_dtype``. Quantization then happens in
        ``process_weights_after_loading``.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the layout the kernels expect.

        Block-wise weights stay ``[N, K]`` and ``weight_scale_inv`` is renamed
        to ``weight_scale``. Non-block weights, including FP16/BF16 weights
        quantized here, are stored transposed as ``[K, N]`` before any Marlin
        repacking.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        # In batch-invariant mode apply() only runs the dequantize path. That
        # path expects the block scales exactly as produced above, i.e.
        # [ceil(N/block_n), ceil(K/block_k)]. Kernel-specific post-processing
        # may requantize or re-lay-out the scales, so it is skipped there.
        if self.block_quant and not vllm_is_batch_invariant():
            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)

    def _dequantize_weight(
        self, layer: torch.nn.Module, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return ``layer.weight`` dequantized to ``dtype`` in ``[N, K]`` layout.

        Used by the batch-invariant path of :meth:`apply`.

        Raises:
            ValueError: If a block-wise scale does not match the weight's
                block grid in either orientation.
        """
        weight = layer.weight.to(dtype)
        weight_scale = layer.weight_scale.to(dtype)

        if not self.block_quant:
            # Non-block weights are stored transposed as [K, N]. A single
            # scale broadcasts directly. A 1-D scale of length N (per output
            # channel) broadcasts along the last axis, which is the output
            # axis in this layout.
            return (weight * weight_scale).t()

        # Block-wise: the weight is [N, K] and the scale holds one entry per
        # (block_n x block_k) tile.
        assert self.weight_block_size is not None
        block_n, block_k = self.weight_block_size
        n, k = weight.shape
        expected = ((n + block_n - 1) // block_n, (k + block_k - 1) // block_k)
        scale_shape = tuple(weight_scale.shape)
        if scale_shape != expected:
            if scale_shape == expected[::-1]:
                # The scale arrived transposed (e.g. after kernel-specific
                # post-processing). Normalize it to [N-blocks, K-blocks].
                weight_scale = weight_scale.t()
            else:
                raise ValueError(
                    f"Unexpected FP8 block weight scale shape {scale_shape} "
                    f"for weight shape {(n, k)} and block size "
                    f"{(block_n, block_k)}"
                )

        scale_full = weight_scale.repeat_interleave(block_n, dim=0).repeat_interleave(
            block_k, dim=1
        )
        # Partial tiles at the edges make scale_full larger than the weight.
        return weight * scale_full[:n, :k]

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the FP8 weights of ``layer``.

        Dispatch order:
        1. Batch-invariant mode: dequantize and use ``F.linear``.
        2. Marlin weight-only kernel.
        3. Block-wise W8A8 op.
        4. Per-tensor or per-token W8A8 op.
        """
        # If batch invariant mode is enabled, dequantize into the activation
        # dtype so F.linear sees matching dtypes (bf16 and fp16 are both
        # supported activation dtypes).
        if vllm_is_batch_invariant():
            weight_nk = self._dequantize_weight(layer, x.dtype)
            return torch.nn.functional.linear(x, weight_nk, bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling: the weights are quantized to FP8 and their scaling
    factors are computed after the model weights are loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        # Block-wise quantization is enabled iff the config carries a block size.
        self.block_quant: bool = self.weight_block_size is not None

        # Optional modular kernel; when set it takes precedence over the
        # FlashInfer/Triton paths in ``apply`` (but not over AITER/Marlin).
        self.fused_experts: mk.FusedMoEModularKernel | None = None  # type: ignore

        self.fp8_backend = get_fp8_moe_backend(self.block_quant)

        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights, weight scales and input scales on ``layer``.

        Args:
            layer: Module on which the parameters are registered.
            num_experts: Leading dimension of every per-expert parameter.
            hidden_size: Last dim of ``w13_weight`` and middle dim of
                ``w2_weight``.
            intermediate_size_per_partition: Per-partition intermediate size.
                ``w13_weight`` stacks two projections, so it has twice this
                many rows.
            params_dtype: Dtype of the checkpoint weights. It is replaced by
                ``torch.float8_e4m3fn`` when the checkpoint is FP8-serialized.
            **extra_weight_attrs: Attributes (e.g. weight loaders) attached to
                the created parameters via ``set_weight_attrs``.

        Raises:
            ValueError: If the block sizes do not divide
                ``intermediate_size_per_partition``, or if a static activation
                scheme is configured for a checkpoint that is not FP8.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding up partial tiles.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the form the chosen kernel uses.

        Exactly one of three paths runs:

        * Block-quantized FP8: optional e4m3fn -> e4m3fnuz normalization, or a
          w13 -> w31 half swap for FlashInfer. The tensors are re-wrapped as
          plain Parameters, then optionally AITER-shuffled and, for DeepGEMM,
          given column-major TMA-aligned scales.
        * Unquantized (FP16/BF16) checkpoint: each expert's w13 and w2 are
          quantized to FP8 with a single dynamic scale each.
        * Per-tensor FP8 checkpoint: static input scales are collapsed to their
          max, and the weights are optionally normalized to fnuz. Both w13
          shards are requantized to one shared per-expert scale, followed by
          optional AITER shuffle and FlashInfer-specific preparation.

        Marlin repacking and DeepGEMM UE8M0 requantization are applied
        afterwards when enabled.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled,
            shuffle_weights,
        )

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w13_weight_scale_inv
                    )
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w2_weight_scale_inv
                    )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        else:
            # Checkpoint is fp8: the MoE kernels require a single activation
            # scale and a single w13 weight scale per expert.

            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # w13 holds two stacked shards, each with its own loaded scale.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Pass the layer's own w2 parameter: no local ``w2_weight``
                    # exists on this path (except on fnuz), and the rotation
                    # must land on the tensor the layer actually holds.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )

    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for this backend, if any.

        Returns ``None`` for ROCm AITER, Marlin and FlashInfer TRT-LLM. Returns
        the FlashInfer CUTLASS-specific object for that backend, and otherwise
        defers to the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation to pair with ``prepare_finalize``.

        * Batched activation format -> ``BatchedTritonOrDeepGemmExperts``.
        * FlashInfer CUTLASS backend -> ``select_cutlass_fp8_gemm_impl``.
        * Otherwise -> ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 MoE quant config from ``layer``'s scales.

        Block-quantized layers supply the ``*_weight_scale_inv`` tensors;
        per-tensor layers supply ``*_weight_scale``. Returns ``None`` for
        Marlin, whose input scales are deleted after loading.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass for ``x``.

        The FlashInfer TRT-LLM backend (with no modular kernel set) does its
        own routing and returns immediately. Every other backend first routes
        with ``FusedMoE.select_experts`` and then dispatches, in order, to
        ROCm AITER, Marlin, ``self.fused_experts``, FlashInfer CUTLASS, or the
        generic ``fused_experts``.

        Returns:
            The expert output. When the layer defines zero experts
            (``zero_expert_num != 0`` and a ``zero_expert_type``), returns
            ``(result, zero_expert_result)`` instead.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if (
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
            and self.fused_experts is None
        ):
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )
            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                assert (
                    renormalize and use_grouped_topk and custom_routing_function is None
                )
                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert not renormalize and custom_routing_function is not None
                # Return directly, mirroring the block-quant branch. Falling
                # through would re-route and recompute with the generic kernel
                # below, overwriting this result, on w13 weights that were
                # swapped for FlashInfer in process_weights_after_loading.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )

        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        #
        # Note: the order of checks is important since self.fused_experts
        # can override fused_experts or cutlass but not rocm or marlin.
        #
        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            assert self.fused_experts is None
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            assert self.fused_experts is None
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif self.fused_experts is not None:
            result = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert not renormalize and custom_routing_function is not None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )

            result = flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )
        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1349
1350
1351
    """

    def __init__(self, quant_config: Fp8Config):
1352
        super().__init__(quant_config)