fp8.py 52.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

bnellnm's avatar
bnellnm committed
4
from typing import TYPE_CHECKING, Any, Callable, Optional
5
6

import torch
7
import torch.nn.functional as F
8
9
10
from torch.nn import Module
from torch.nn.parameter import Parameter

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
18
19
    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
20
21
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
22
from vllm.model_executor.layers.quantization import QuantizationMethods
23
from vllm.model_executor.layers.quantization.base_config import (
24
    QuantizationConfig, QuantizeMethodBase)
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
27
28
29
30
31
    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
32
33
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
34
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
35
36
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
37
from vllm.model_executor.layers.quantization.utils.quant_utils import (
38
    GroupShape, is_layer_skipped)
39
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
40
41
42
43
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
44
45
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
46
                                           PerTensorScaleParameter)
47
from vllm.model_executor.utils import set_weight_attrs
48
from vllm.platforms import current_platform
49
from vllm.scalar_type import scalar_types
50
from vllm.utils import has_deep_gemm
51
52
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
                                  is_deep_gemm_supported)
53
from vllm.utils.flashinfer import has_flashinfer_moe
54

55
56
57
# Only needed for the string annotation on Fp8Config.apply_vllm_mapper;
# TYPE_CHECKING is False at runtime, so this import is never executed there.
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

# Activation quantization schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]


def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

68

69
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Records whether the checkpoint is already FP8-serialized, which
    activation quantization scheme to use ("static" or "dynamic"), which
    layers to leave unquantized, and an optional 2-D block size for
    block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Raises:
            ValueError: if ``activation_scheme`` is not in
                ``ACTIVATION_SCHEMES``, or if ``weight_block_size`` is given
                with a non-serialized checkpoint, does not have exactly two
                entries, or is combined with a non-dynamic activation scheme.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._validate_weight_block_size(weight_block_size,
                                             is_checkpoint_fp8_serialized,
                                             activation_scheme)
        self.weight_block_size = weight_block_size

    @staticmethod
    def _validate_weight_block_size(weight_block_size: list[int],
                                    is_checkpoint_fp8_serialized: bool,
                                    activation_scheme: str) -> None:
        """Raise ValueError if block-wise quantization cannot be used."""
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now.")
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions")
        if activation_scheme != "dynamic":
            raise ValueError("The block-wise quantization only supports "
                             "dynamic activation scheme for now, but got "
                             f"{activation_scheme} activation scheme.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config files to look for (none)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` names with ``hf_to_vllm_mapper``."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict.

        The checkpoint counts as fp8-serialized when its ``quant_method``
        contains "fp8". ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config,
                                                 ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(config,
                                                  ["modules_to_not_convert"],
                                                  None)
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers listed in ``ignored_layers`` stay unquantized; other
        linear, fused-MoE and attention layers get their FP8 method. Any
        other layer type gets ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(prefix=prefix,
                                       ignored_layers=self.ignored_layers,
                                       fused_mapping=self.packed_modules_mapping)
            if skipped:
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
        """
        if name.endswith(".output_scale"):
            # (projection marker, checkpoint suffix, vLLM suffix), checked in
            # this order.
            for marker, old_suffix, new_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    When ``quant_config.weight_block_size`` is set, weights are block-wise
    quantized (fp8-serialized checkpoints with dynamic activation scaling
    only) and their scales are stored as ``weight_scale_inv``.

    Limitations:
    1. For non-block fp8 checkpoints on the non-Marlin path, the per-shard
       weight scales are collapsed into a single per-tensor scale, because
       torch._scaled_mm needs a per-tensor weight scale.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None

        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        ``weight`` has shape (output_size_per_partition,
        input_size_per_partition). It is stored as ``torch.float8_e4m3fn``
        for fp8-serialized checkpoints and as ``params_dtype`` otherwise.
        Scale parameters are only created for fp8-serialized checkpoints,
        and are initialized to float32 min:
          * ``weight_scale``: one entry per logical output partition, or
          * ``weight_scale_inv``: one entry per weight block (block quant);
          * ``input_scale``: per output partition for the static scheme,
            registered as ``None`` for the dynamic scheme.

        Raises:
            ValueError: under block quantization, when a row-parallel input
                partition or a column-parallel / merged output partition is
                not divisible by the block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(
            data=torch.empty(output_size_per_partition,
                             input_size_per_partition,
                             dtype=weight_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight``, padded in memory on ROCm when enabled.

        The padded tensor is sliced back to the original shape, so only the
        row stride changes.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded weights into the layout ``apply`` expects.

        Three mutually exclusive paths:
          * block quant: optionally normalize to e4m3fnuz, pad, and re-wrap
            ``weight`` / ``weight_scale_inv`` as plain Parameters;
          * non-fp8 checkpoint: quantize ``weight`` via
            ``ops.scaled_fp8_quant(..., scale=None)`` and store it
            transposed;
          * fp8 checkpoint: unless Marlin is used, requantize the logical
            shards with their max scale, then pad and store transposed.
        Afterwards the layer is prepared for Marlin when enabled, and
        block-quantized weights are requantized in place for DeepGemm E8M0
        when that mode is active.
        """
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            # layer.input_scale is None indicates dynamic quant and scale is
            # computed from input.
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On B200, if E8M0 for DeepGemm is used, we need to
        # requantize the weight and input to the specific scale
        # at the same time.
        # Only block-quantized layers have a weight_block_size (create_weights
        # leaves it None otherwise) and block scales to requantize; non-block
        # layers go through Fp8LinearOp in apply(). Gating on block_quant
        # keeps per-tensor FP8 layers from tripping the assertion below.
        if self.block_quant and is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatches to the Marlin kernel when ``use_marlin`` is set, to the
        block-scaled FP8 op when the weights are block quantized, and to
        ``Fp8LinearOp`` otherwise.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

490
491
492
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Decide which FP8 MoE kernels are available for ``layer``.

        Records which optional backends are enabled: FlashInfer (via env
        flag), Marlin (no native FP8, or forced; never on ROCm), DeepGemm
        and CutlassBlockScaledGroupedGemm (block-quantized models only, plus
        platform checks). Each decision is logged.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.block_quant = quant_config.weight_block_size is not None

        # FlashInfer backend (optional).
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            backend = get_flashinfer_moe_backend()
            self.flashinfer_moe_backend = backend
            logger.info_once(f"Using FlashInfer {backend.value} kernels")

        # Marlin gives weight-only FP8 on GPUs without native FP8 support
        # (capability < 8.9) or when forced for testing; never on ROCm.
        marlin_requested = (not current_platform.has_device_capability(89)
                            or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        self.use_marlin = (False
                           if current_platform.is_rocm() else marlin_requested)

        # DeepGemm: needs the env flag, an importable package, a
        # block-quantized model and a supported platform.
        deep_gemm_enabled = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif not is_deep_gemm_supported():
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")
            else:
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                deep_gemm_enabled = True
        self.allow_deep_gemm = deep_gemm_enabled

        # CutlassBlockScaledGroupedGemm: block-quantized models on CUDA
        # devices with capability 10.0 only.
        cutlass_grouped_enabled = False
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
        elif not (current_platform.is_cuda()
                  and current_platform.is_device_capability(100)):
            logger.warning_once(
                "CutlassBlockScaledGroupedGemm not supported on the current "
                "platform.")
        else:
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
            cutlass_grouped_enabled = True
        self.allow_cutlass_block_scaled_grouped_gemm = cutlass_grouped_enabled

    def maybe_make_prepare_finalize(
        self,
        moe: FusedMoEConfig,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Return the prepare/finalize stage for this MoE layer.

        Uses the FlashInfer FP8 CUTLASS builder when the FlashInfer CUTLASS
        backend was selected in ``__init__``; otherwise defers to the base
        class implementation.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                moe, layer=self.layer)
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(moe)

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the expert weights, weight scales and input scales.

        Weights:
          * ``w13_weight``: (num_experts, 2 * intermediate_size_per_partition,
            hidden_size)
          * ``w2_weight``: (num_experts, hidden_size,
            intermediate_size_per_partition)
        Both are ``torch.float8_e4m3fn`` for fp8-serialized checkpoints and
        ``params_dtype`` otherwise.

        Weight scales (float32, initialized to ones):
          * per-tensor: ``w13_weight_scale`` (num_experts, 2) and
            ``w2_weight_scale`` (num_experts,);
          * block quant: ``w13_weight_scale_inv`` / ``w2_weight_scale_inv``
            with one entry per weight block.
        The scales get the weight-loader attributes only for fp8-serialized
        checkpoints. With the static activation scheme, per-expert
        ``w13_input_scale`` / ``w2_input_scale`` are added; otherwise both
        are set to ``None``.

        Raises:
            ValueError: if ``intermediate_size_per_partition`` is not
                divisible by the block sizes, or if the static activation
                scheme is used with a non-fp8-serialized checkpoint.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype

        def _cdiv(numer: int, denom: int) -> int:
            # Ceiling division for positive integers.
            return (numer + denom - 1) // denom

        if self.block_quant:
            block_size = self.quant_config.weight_block_size
            assert block_size is not None
            layer.weight_block_size = block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = block_size[0], block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        weight_shapes = (
            ("w13_weight", (num_experts, 2 * intermediate_size_per_partition,
                            hidden_size)),
            ("w2_weight", (num_experts, hidden_size,
                           intermediate_size_per_partition)),
        )
        for param_name, shape in weight_shapes:
            param = torch.nn.Parameter(torch.empty(*shape, dtype=weight_dtype),
                                       requires_grad=False)
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            w13_scale_shape = (
                num_experts,
                2 * _cdiv(intermediate_size_per_partition, block_n),
                _cdiv(hidden_size, block_k),
            )
            w2_scale_shape = (
                num_experts,
                _cdiv(hidden_size, block_n),
                _cdiv(intermediate_size_per_partition, block_k),
            )
            w13_scale_name = "w13_weight_scale_inv"
            w2_scale_name = "w2_weight_scale_inv"
        else:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_scale_shape = (num_experts, 2)
            w2_scale_shape = (num_experts, )
            w13_scale_name = "w13_weight_scale"
            w2_scale_name = "w2_weight_scale"

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(*w13_scale_shape, dtype=torch.float32),
            requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(*w2_scale_shape, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter(w13_scale_name, w13_weight_scale)
        layer.register_parameter(w2_scale_name, w2_weight_scale)
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        scale_support = (FusedMoeWeightScaleSupported.BLOCK if self.block_quant
                         else FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs.update({"quant_method": scale_support.value})

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded MoE weights/scales into the layout kernels use.

        Exactly one of three paths is taken:

        * Block-quantized checkpoint (``self.block_quant``):
          - normalize via ``normalize_e4m3fn_to_e4m3fnuz`` on fnuz
            platforms, or swap the w1/w3 halves of w13 for FlashInfer;
          - re-wrap the tensors as plain ``Parameter``s;
          - shuffle them for ROCm AITER;
          - pre-align the block scales for DeepGEMM.
        * Non-fp8 checkpoint (e.g. fp16): quantize every expert's w13 and
          w2 to fp8 with ``ops.scaled_fp8_quant``, one scale per expert.
        * fp8 per-tensor checkpoint:
          - collapse static input scales to their max across experts;
          - normalize on fnuz platforms;
          - requantize the two w13 shards of each expert to one shared
            (max) scale;
          - apply the ROCm AITER / FlashInfer layout changes.

        Afterwards, Marlin and Blackwell DeepGEMM (UE8M0) post-processing
        is applied when enabled.

        Args:
            layer: The fused-MoE layer whose parameters were registered in
                ``create_weights`` and filled by the weight loader.

        Raises:
            ValueError: If the activation scheme is static but the input
                scales are missing.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                # Read the raw tensors (not the Parameter objects) so the
                # re-wrap below is consistent with the w13 handling.
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = (
                        get_col_major_tma_aligned_tensor(
                            layer.w13_weight_scale_inv).contiguous())
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = (
                        get_col_major_tma_aligned_tensor(
                            layer.w2_weight_scale_inv).contiguous())

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(layer.local_num_experts,
                           dtype=torch.float32,
                           device=w13_weight.device),
                requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]))
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]))
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # w13 holds two stacked shards (w1 and w3), each with its
                # own loaded scale; requantize both to the shared max.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if (self.flashinfer_moe_backend ==
                        FlashinferMoeBackend.TENSORRT_LLM):
                    # Use the layer's current w2 tensor: a local `w2_weight`
                    # is only bound on the fnuz path above, so referencing
                    # it here would raise NameError on other platforms.
                    rotate_flashinfer_fp8_moe_weights(w13_weight,
                                                      layer.w2_weight.data)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv).contiguous()
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv).contiguous()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the fp8 experts implementation to pair with ``prepare_finalize``.

        Selection order:

        1. A batched-experts activation format yields
           ``BatchedTritonOrDeepGemmExperts``.
        2. Otherwise, the FlashInfer CUTLASS backend uses
           ``select_cutlass_fp8_gemm_impl``.
        3. Everything else gets ``TritonOrDeepGemmExperts``.

        Marlin and ROCm AITER are rejected up front (asserted).
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        block_shape = self.quant_config.weight_block_size
        wants_batched = (prepare_finalize.activation_format ==
                         FusedMoEActivationFormat.BatchedExperts)

        if wants_batched:
            tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, tokens_per_rank, block_shape, False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                use_fp8_w8a8=True,
                block_shape=block_shape,
                per_act_token_quant=False,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            cutlass_impl = select_cutlass_fp8_gemm_impl(moe, self.layer)
            logger.debug_once("Using %s", cutlass_impl.__class__.__name__)
            return cutlass_impl

        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__, block_shape, False)
        return TritonOrDeepGemmExperts(
            use_fp8_w8a8=True,
            block_shape=block_shape,
            allow_deep_gemm=self.allow_deep_gemm,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the fp8 fused-MoE forward pass for ``x``.

        The FlashInfer TRT-LLM backend consumes ``router_logits`` directly
        and returns early. Every other backend first routes tokens via
        ``FusedMoE.select_experts`` and then dispatches, in priority order,
        to:

        1. ROCm AITER;
        2. Marlin;
        3. FlashInfer CUTLASS;
        4. the generic ``fused_experts`` path, or ``self.fused_experts``
           when one is set.

        Block-quantized layers pass the ``*_weight_scale_inv`` tensors as
        weight scales; per-tensor layers pass ``*_weight_scale``.

        Returns:
            The tensor produced by the selected MoE kernel.
        """
        if enable_eplb:
            # The EPLB bookkeeping tensors are forwarded to select_experts.
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.block_quant:
                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)

                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.quant_config.weight_block_size,
                    routed_scaling=1.0,
                )
            else:
                assert (not renormalize
                        and custom_routing_function is not None)
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # `self.block_quant` is used as a boolean flag throughout this
            # class, so test its truthiness. An `is None` check would reject
            # every per-tensor configuration, because the flag is False
            # there, not None.
            assert not self.block_quant
            assert (not renormalize and custom_routing_function is not None)
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.fused_experts is not None:
                return self.fused_experts(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
            else:
                return flashinfer_cutlass_moe_fp8(
                    x,
                    layer,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
        else:
            common_kwargs = dict(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )

            if self.fused_experts is not None:
                return self.fused_experts(**common_kwargs)
            else:
                from vllm.model_executor.layers.fused_moe import fused_experts
                return fused_experts(
                    **common_kwargs,
                    use_fp8_w8a8=True,
                    block_shape=self.quant_config.weight_block_size,
                    allow_deep_gemm=self.allow_deep_gemm,
                    allow_cutlass_block_scaled_grouped_gemm=(
                        self.allow_cutlass_block_scaled_grouped_gemm),
                )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1139
1140
1141
    """

    def __init__(self, quant_config: Fp8Config):
1142
        super().__init__(quant_config)