# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
5
6

import torch
7
import torch.nn.functional as F
8
9
10
from torch.nn import Module
from torch.nn.parameter import Parameter

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
18
19
    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
20
21
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
22
from vllm.model_executor.layers.quantization import QuantizationMethods
23
from vllm.model_executor.layers.quantization.base_config import (
24
    QuantizationConfig, QuantizeMethodBase)
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
27
28
29
30
31
    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
32
33
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
34
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
35
36
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
37
from vllm.model_executor.layers.quantization.utils.quant_utils import (
38
    GroupShape, is_layer_skipped)
39
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
40
41
42
43
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
44
45
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
46
                                           PerTensorScaleParameter)
47
from vllm.model_executor.utils import set_weight_attrs
48
from vllm.platforms import current_platform
49
from vllm.scalar_type import scalar_types
50
from vllm.utils import has_deep_gemm
51
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
52
from vllm.utils.flashinfer import has_flashinfer_moe
53

54
55
56
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

57
58
59
60
# Activation quantization schemes accepted by Fp8Config. With "static" an
# input scale parameter is created and loaded from the checkpoint; with
# "dynamic" no input scale is stored and it is computed from the input.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger for this file.
logger = init_logger(__name__)

61
62
63
64
65
66

def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

67

68
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. When False, weights are quantized after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer names that are left unquantized.
        weight_block_size: Optional two-element ``[block_n, block_k]`` list
            enabling block-wise weight quantization. Requires an
            FP8-serialized checkpoint and the "dynamic" activation scheme.

    Raises:
        ValueError: If the activation scheme is unknown or the block-wise
            quantization constraints above are violated.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # Normalize None to an empty list so later code can iterate freely.
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method supports."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this method."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config file names to look for (none for FP8)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` through ``hf_to_vllm_mapper``."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(
                self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dict.

        The checkpoint is treated as FP8-serialized when the "quant_method"
        entry contains "fp8". "ignored_layers" falls back to
        "modules_to_not_convert" when it is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(config,
                                                  ["modules_to_not_convert"],
                                                  None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_xpu_quant_method(self, layer: torch.nn.Module,
                             prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer`` on XPU platforms.

        Returns None when ``layer`` is not a linear, fused-MoE or attention
        layer.
        """
        # Imported lazily, matching get_quant_method.
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod, XPUFp8MoEMethod)
        # The XPU methods receive a fresh copy of this config's settings.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size)

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer``.

        Linear layers listed in ``ignored_layers`` get the unquantized
        method. Returns None when ``layer`` is not a linear, fused-MoE or
        attention layer.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

201
202
203

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        # Block-wise weight quantization is enabled by a configured block size.
        self.block_quant = self.quant_config.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        Always registers ``weight``. For FP8-serialized checkpoints it also
        registers ``weight_scale`` (per-tensor) or ``weight_scale_inv``
        (block-wise) and ``input_scale`` (a parameter for the static scheme,
        None for dynamic).

        Raises:
            ValueError: If block-wise quantization is enabled and the
                partition sizes are not divisible by the block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        # Stays None for per-tensor layers; set below for block quant.
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = getattr(layer, "tp_size",
                              get_tensor_model_parallel_world_size())
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                # Pre-fill with the float32 minimum before loading.
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounding up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight``, re-laid-out with row padding on ROCm if enabled.

        The returned tensor has the same shape as the input; only its
        underlying storage layout may change.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            # Pad then slice back: shape is preserved, row stride grows.
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the weights and scales of ``layer`` after loading.

        Handles three cases: block-quantized checkpoints, non-FP8
        checkpoints (quantized here), and per-tensor FP8 checkpoints
        (requantized to a single scale unless Marlin is used). Afterwards
        the layer is optionally prepared for Marlin, and block-quantized
        layers are requantized for DeepGEMM E8M0 when that mode is active.
        """
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            # layer.input_scale is None indicates dynamic quant and scale is
            # computed from input.
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On B200, if E8M0 for DeepGemm is used, we need to
        # requantize the weight and input to the specific scale
        # at the same time.
        # Only block-quantized layers have a weight_block_size; per-tensor
        # layers leave it as None (see create_weights) and must skip this
        # step rather than fail the assertion below.
        if self.block_quant and is_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the FP8 linear layer to ``x``.

        Dispatches to the Marlin kernel, the block-wise W8A8 op, or the
        generic ``Fp8LinearOp``, in that order of precedence.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
500
501


502
503
504
505
506
507
508
509
510
511
512
513
514
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

515
516
517
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Select the kernel backends for the FP8 MoE layer.

        Args:
            quant_config: The quantization config.
            layer: The MoE layer; its ``moe_config`` is passed to the base
                class and the layer itself is kept on ``self.layer``.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        # Block-wise weight quantization is enabled by a configured block size.
        self.block_quant = self.quant_config.weight_block_size is not None

        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )
        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif (is_deep_gemm_supported()):
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
            else:
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

        # Check for CutlassBlockScaledGroupedGemm support.
        self.allow_cutlass_block_scaled_grouped_gemm = False
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.is_device_capability(100)):
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
            self.allow_cutlass_block_scaled_grouped_gemm = True
        else:
            logger.warning_once(
                "CutlassBlockScaledGroupedGemm not supported on the current "
                "platform.")

568
569
570
571
572
573
574
575
576
577
578
579
580
581
    def maybe_make_prepare_finalize(
        self,
        moe: FusedMoEConfig,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Return the prepare/finalize object to use for this MoE layer.

        With the FlashInfer CUTLASS backend a dedicated FP8 prepare/finalize
        object is built for ``self.layer``; in every other case the base
        class implementation decides.
        """
        uses_flashinfer_cutlass = (
            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS)
        if uses_flashinfer_cutlass:
            flashinfer_pf = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                moe,
                layer=self.layer,
            )
            # Record which concrete implementation was picked.
            logger.debug_once("%s", flashinfer_pf.__class__.__name__)
            return flashinfer_pf

        return super().maybe_make_prepare_finalize(moe)

582
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
583
584
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
585

586
587
588
589
590
591
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

592
593
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
594
595
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
596
            layer.weight_block_size = self.quant_config.weight_block_size
597
598
599
600
601
602
603
604
605
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
606
            if intermediate_size_per_partition % block_n != 0:
607
608
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
609
                    f"{intermediate_size_per_partition} is not divisible by "
610
                    f"weight quantization block_n = {block_n}.")
611
612
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
613
                # Required by row parallel
614
615
616
617
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
618
619

        # WEIGHTS
620
621
622
623
624
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
625
626
627
628
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

629
630
631
632
633
        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
634
635
636
637
638
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
654
655
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
656
657
658
659
660
661
662
663
664
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
665
                    (intermediate_size_per_partition + block_k - 1) // block_k,
666
667
668
669
670
671
672
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
673

674
675
676
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
677
678
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
             value} if self.block_quant else
679
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
680
681
682
683
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
684
685
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
686
687
688
689
690
691
692
693

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

694
695
696
697
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
698
            set_weight_attrs(w13_input_scale, extra_weight_attrs)
699
700
701
702
703

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
704
705
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

706
        else:
707
708
            layer.w13_input_scale = None
            layer.w2_input_scale = None
709
710

    def process_weights_after_loading(self, layer: Module) -> None:
        """Rewrite the loaded expert weights/scales into kernel-ready form.

        Three mutually exclusive cases are handled:

        * block quantization (``self.block_quant``): optional e4m3fn ->
          e4m3fnuz normalization or FlashInfer w13 -> w31 swap, followed by
          re-registering the tensors as plain ``Parameter`` objects;
        * non-fp8 checkpoint: every expert is quantized to fp8 here with one
          scale per expert;
        * fp8 checkpoint with per-tensor scales: input scales are collapsed
          to their maximum and the two w13 shards of each expert are
          requantized to a shared scale.

        Afterwards the Marlin and DeepGEMM-E8M0 specific preparation steps
        run if they are enabled.

        Args:
            layer: The MoE layer whose ``w13_*`` / ``w2_*`` attributes are
                read and replaced in place.

        Raises:
            ValueError: If the activation scheme is "static" but one of the
                input scales on ``layer`` is ``None``.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # The returned input scales are not needed for block quant.
                w13_weight, w13_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                # The scale attributes are registered parameters at this
                # point, and nn.Module refuses to overwrite a registered
                # parameter with a plain tensor, so re-wrap the result.
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = Parameter(
                        get_col_major_tma_aligned_tensor(
                            layer.w13_weight_scale_inv).contiguous(),
                        requires_grad=False)
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = Parameter(
                        get_col_major_tma_aligned_tensor(
                            layer.w2_weight_scale_inv).contiguous(),
                        requires_grad=False)

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.local_num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == \
                    FlashinferMoeBackend.TENSORRT_LLM:
                    # Read w2 from the layer: a local ``w2_weight`` only
                    # exists when the fnuz normalization above ran.
                    # NOTE(review): presumably the helper rotates both
                    # tensors in place (w2 is never reassigned) - confirm.
                    rotate_flashinfer_fp8_moe_weights(w13_weight,
                                                      layer.w2_weight.data)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            # Re-wrapped as Parameter for the same reason as above: a
            # registered parameter cannot be replaced by a plain tensor.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = Parameter(
                    get_col_major_tma_aligned_tensor(
                        layer.w13_weight_scale_inv).contiguous(),
                    requires_grad=False)
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = Parameter(
                    get_col_major_tma_aligned_tensor(
                        layer.w2_weight_scale_inv).contiguous(),
                    requires_grad=False)

bnellnm's avatar
bnellnm committed
922
923
924
925
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching ``prepare_finalize``.

        Args:
            prepare_finalize: Its ``activation_format`` decides whether the
                batched experts implementation is used.
            moe: Forwarded to ``select_cutlass_fp8_gemm_impl`` on the
                FlashInfer CUTLASS path.
            layer: Not read by this method.

        Returns:
            A ``BatchedTritonOrDeepGemmExperts`` for the batched-experts
            activation format, the FlashInfer CUTLASS implementation when
            that backend is selected, otherwise ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        wants_batched = (prepare_finalize.activation_format ==
                         FusedMoEActivationFormat.BatchedExperts)
        if wants_batched:
            tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, tokens_per_rank,
                self.quant_config.weight_block_size, False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                per_act_token_quant=False,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # NOTE(review): this passes self.layer rather than the `layer`
            # argument - confirm that is intentional.
            cutlass_experts = select_cutlass_fp8_gemm_impl(moe, self.layer)
            logger.debug_once("Using %s", cutlass_experts.__class__.__name__)
            return cutlass_experts

        # Default: standard (non-batched) Triton / DeepGEMM experts.
        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__, self.quant_config.weight_block_size,
            False)
        return TritonOrDeepGemmExperts(
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
        )

970
971
972
973
974
975
976
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the FP8 MoE forward pass for ``layer`` on input ``x``.

        The FlashInfer TensorRT-LLM backend is dispatched first and does its
        own routing. Every other backend shares the expert selection done by
        ``FusedMoE.select_experts`` and is then tried in this order: ROCm
        AITER, Marlin, FlashInfer CUTLASS, and finally the generic
        ``fused_experts`` path (or ``self.fused_experts`` when it is set).

        Args:
            layer: MoE layer holding the ``w13_*`` / ``w2_*`` weights and
                scales.
            x: Input hidden states.
            router_logits: Router output used for expert selection.
            top_k, renormalize, use_grouped_topk, topk_group,
                num_expert_group, custom_routing_function, scoring_func,
                routed_scaling_factor, e_score_correction_bias: Routing
                options forwarded to the selected routing implementation.
            global_num_experts, expert_map, apply_router_weight_on_input,
                activation: Forwarded to the selected experts kernel.
            enable_eplb, expert_load_view, logical_to_physical_map,
                logical_replica_count: EPLB inputs; the three tensors must
                be provided when ``enable_eplb`` is true.

        Returns:
            Whatever the selected kernel returns.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.block_quant:
                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)

                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.quant_config.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert (not renormalize
                        and custom_routing_function is not None)
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # block_quant is used as a boolean flag elsewhere in this class,
            # so test its truthiness; an identity check against None would
            # reject False as well.
            assert not self.block_quant
            assert (not renormalize and custom_routing_function is not None)
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.fused_experts is not None:
                return self.fused_experts(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
            else:
                return flashinfer_cutlass_moe_fp8(
                    x,
                    layer,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
        else:
            # Arguments shared by the modular kernel and the plain
            # fused_experts fallback.
            common_kwargs = dict(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )

            if self.fused_experts is not None:
                return self.fused_experts(**common_kwargs)
            else:
                from vllm.model_executor.layers.fused_moe import fused_experts
                return fused_experts(
                    **common_kwargs,
                    use_fp8_w8a8=True,
                    block_shape=self.quant_config.weight_block_size,
                    allow_deep_gemm=self.allow_deep_gemm,
                    allow_cutlass_block_scaled_grouped_gemm=(
                        self.allow_cutlass_block_scaled_grouped_gemm),
                )
1162
1163


1164
1165
1166
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1167
1168
1169
    """

    def __init__(self, quant_config: Fp8Config):
1170
        super().__init__(quant_config)