fp8.py 50 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
5
6
7
8
9

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm.distributed import get_tensor_model_parallel_world_size
14
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
15
from vllm.model_executor.layers.fused_moe import (
16
    FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
17
18
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
19
20
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
21
22
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
23
from vllm.model_executor.layers.quantization import QuantizationMethods
24
from vllm.model_executor.layers.quantization.base_config import (
25
    QuantizationConfig, QuantizeMethodBase)
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
28
29
30
31
32
    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
33
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
34
    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
35
    create_fp8_input_scale, create_fp8_scale_parameter,
36
    create_fp8_weight_parameter, expert_weight_is_col_major,
37
38
39
    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
    validate_fp8_block_shape)
40
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
41
42
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
43
from vllm.model_executor.layers.quantization.utils.quant_utils import (
44
    GroupShape, is_layer_skipped)
45
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
46
47
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
48
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
49
50
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
51
                                           PerTensorScaleParameter)
52
from vllm.model_executor.utils import set_weight_attrs
53
from vllm.platforms import current_platform
54
from vllm.scalar_type import scalar_types
55
from vllm.utils import has_deep_gemm
56
57
58
from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
                                  is_deep_gemm_e8m0_used,
                                  is_deep_gemm_supported)
59
from vllm.utils.flashinfer import has_flashinfer_moe
60

61
62
63
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

64
65
66
67
# Activation quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger for this quantization backend.
logger = init_logger(__name__)

68

69
class Fp8Config(QuantizationConfig):
    """Configuration for FP8 quantization.

    Records whether the checkpoint already stores FP8 weights, which
    activation scheme ("static" or "dynamic") is used, which layers are left
    unquantized, and an optional two-element weight block size that enables
    block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # A missing or empty list is normalized to [].
        self.ignored_layers = ignored_layers or []

        # Block-wise weight quantization requires an fp8-serialized
        # checkpoint, exactly two block dimensions and dynamic activations.
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            block_rank = len(weight_block_size)
            if block_rank != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {block_rank} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method ("fp8")."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes supported by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config file names to look for (none for FP8)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` through the HF-to-vLLM name mapper."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        ``quant_method`` and ``activation_scheme`` are read with
        ``get_from_keys`` (no default). The checkpoint counts as
        fp8-serialized when ``quant_method`` contains "fp8".
        ``ignored_layers`` falls back to ``modules_to_not_convert`` when it is
        missing or empty. ``weight_block_size`` defaults to None.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped_layers = (
            cls.get_from_keys_or(config, ["ignored_layers"], None)
            or cls.get_from_keys_or(config, ["modules_to_not_convert"], None))
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=skipped_layers,
            weight_block_size=block_size,
        )

    def get_xpu_quant_method(self, layer: torch.nn.Module,
                             prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the XPU quantize method for ``layer``.

        Linear and fused-MoE layers get IPEX-based XPU methods built from a
        fresh copy of this config. Linear layers listed in ``ignored_layers``
        stay unquantized. Attention layers get ``Fp8KVCacheMethod``. Any other
        layer yields None.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod, XPUFp8MoEMethod)
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size)

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping)
            return (UnquantizedLinearMethod()
                    if skipped else XPUFp8LinearMethod(xpu_config))
        if isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method to attach to ``layer``, or None.

        XPU platforms are delegated to ``get_xpu_quant_method``. Elsewhere:
        linear layers get ``Fp8LinearMethod`` (or ``UnquantizedLinearMethod``
        when skipped), fused-MoE layers get ``Fp8MoEMethod`` and attention
        layers get ``Fp8KVCacheMethod``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """Map a compressed-tensors scale name to vLLM's attention scale name.

        ``*.{k,v,q}_proj.output_scale`` maps to ``*.attn.{k,v,q}_scale``.
        ``*self_attn.prob_output_scale`` maps to
        ``*self_attn.attn.prob_scale``.

        :param name: param name
        :return: the vLLM param name, or None when ``name`` matches no
            known pattern
        """
        if name.endswith(".output_scale"):
            # Checked in the same priority order as k, v, q.
            proj_to_attn_scale = (
                (".k_proj", ".attn.k_scale"),
                (".v_proj", ".attn.v_scale"),
                (".q_proj", ".attn.q_scale"),
            )
            for proj, attn_scale in proj_to_attn_scale:
                if proj in name:
                    return name.replace(f"{proj}.output_scale", attn_scale)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None

202
203
204

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Weights are quantized per tensor or, when the config sets
    ``weight_block_size``, block-wise. If the device lacks capability 89 (and
    is not ROCm), or ``VLLM_TEST_FORCE_FP8_MARLIN`` is set, the Marlin kernel
    is used for weight-only FP8 and activations are not quantized.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
    2. A static activation scheme requires an fp8-serialized checkpoint
       (input scales are only loaded from such checkpoints), except on the
       Marlin path, which does not quantize activations.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # Block quant: activation group shape (1, weight_block_size[0]).
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            # Fp8Config already rejects block quant with static activations.
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight parameter and, for fp8 checkpoints, its scales.

        Args:
            layer: The linear layer receiving the parameters.
            input_size_per_partition: Input features held by this rank.
            output_partition_sizes: Output widths of each logical shard fused
                into this layer. Their sum is this rank's output size.
            input_size: Full (unpartitioned) input size.
            output_size: Full (unpartitioned) output size.
            params_dtype: Dtype for weights of non-fp8 checkpoints.
            **extra_weight_attrs: Must carry ``weight_loader``.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(layer, input_size, output_size,
                                     input_size_per_partition,
                                     output_partition_sizes,
                                     self.weight_block_size)

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(output_size_per_partition,
                                                 input_size_per_partition,
                                                 weight_loader)
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype),
                                          input_dim=1,
                                          output_dim=0,
                                          weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(PerTensorScaleParameter,
                                                   output_partition_sizes,
                                                   input_size_per_partition,
                                                   None, weight_loader)
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(BlockQuantScaleParameter,
                                                   output_partition_sizes,
                                                   input_size_per_partition,
                                                   self.weight_block_size,
                                                   weight_loader)
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes,
                                               weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the layout the kernels expect.

        Block-quant weights go through ``process_fp8_weight_block_strategy``.
        Non-fp8 checkpoints are quantized here with a per-tensor scale.
        Per-tensor fp8 checkpoints have their per-shard scales requantized to
        one scale unless Marlin is used. Non-block weights are transposed.

        Raises:
            ValueError: if the activation scheme is static but the checkpoint
                is not fp8-serialized and Marlin is not in use, since no input
                scale was loaded for the static activation quantizer.
        """
        # Non-block paths transpose the weight below, so Marlin receives it
        # K-first. Block-quant weights are passed untransposed.
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv)
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            # create_weights only registers input_scale for fp8-serialized
            # checkpoints, so a static activation quantizer would receive
            # input_scale=None. Fail at load time instead (matches the check
            # in Fp8MoEMethod). Marlin does not quantize activations.
            if self.act_q_static and not self.use_marlin:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = (
                    process_fp8_weight_tensor_strategy(
                        weight, weight_scale, layer.logical_widths,
                        getattr(layer, 'input_scale', None)))
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = Parameter(
            input_scale,
            requires_grad=False) if input_scale is not None else None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(
                layer, self.cutlass_block_fp8_supported)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear forward on ``x`` with the selected kernel path.

        The path is Marlin (weight-only), then the block-quant op, then the
        per-tensor/per-token ``Fp8LinearOp`` as the default.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
425
426


427
428
429
430
431
432
433
434
435
436
437
438
439
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

440
441
442
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Decide which MoE kernel backends are usable for this config.

        Args:
            quant_config: The FP8 quantization config.
            layer: The fused-MoE layer. Its ``moe_config`` is forwarded to the
                base class.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None

        # FlashInfer is only selected when enabled by env and available.
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            backend = get_flashinfer_moe_backend()
            self.flashinfer_moe_backend = backend
            logger.info_once(f"Using FlashInfer {backend.value} kernels")

        # Marlin gives weight-only FP8 on devices without capability 89, or
        # when forced for testing. It is never used on ROCm.
        lacks_native_fp8 = not current_platform.has_device_capability(89)
        self.use_marlin = lacks_native_fp8 or envs.VLLM_TEST_FORCE_FP8_MARLIN
        if current_platform.is_rocm():
            self.use_marlin = False

        # DeepGemm: opt-in via env, requires block quant and platform support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif not is_deep_gemm_supported():
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")
            else:
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True

        # CutlassBlockScaledGroupedGemm: block quant on CUDA capability 100.
        self.allow_cutlass_block_scaled_grouped_gemm = False
        on_cuda_sm100 = (current_platform.is_cuda()
                         and current_platform.is_device_capability(100))
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
        elif not on_cuda_sm100:
            logger.warning_once(
                "CutlassBlockScaledGroupedGemm not supported on the current "
                "platform.")
        else:
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
            self.allow_cutlass_block_scaled_grouped_gemm = True

494
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register the expert weights and their FP8 scales.

        Registers ``w13_weight`` with shape ``(num_experts,
        2 * intermediate_size_per_partition, hidden_size)`` and ``w2_weight``
        with shape ``(num_experts, hidden_size,
        intermediate_size_per_partition)``.

        Weight scales are either per tensor (``w13_weight_scale`` /
        ``w2_weight_scale``) or per block (``w13_weight_scale_inv`` /
        ``w2_weight_scale_inv``). Input scales are registered only for the
        static activation scheme.

        Raises:
            ValueError: if the partition size is not divisible by the weight
                block size, or if a static activation scheme is combined
                with a checkpoint that is not fp8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = self.weight_block_size[0]
            block_k = self.weight_block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS: loader attrs are attached before "quant_method" is added.
        weight_shapes = (
            ("w13_weight", (num_experts, 2 * intermediate_size_per_partition,
                            hidden_size)),
            ("w2_weight", (num_experts, hidden_size,
                           intermediate_size_per_partition)),
        )
        for param_name, shape in weight_shapes:
            param = torch.nn.Parameter(torch.empty(*shape, dtype=params_dtype),
                                       requires_grad=False)
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            # One scale per (block_n x block_k) tile, rounded up.
            n_inter_blocks = (intermediate_size_per_partition + block_n -
                              1) // block_n
            w13_scale_shape = (
                num_experts,
                2 * n_inter_blocks,
                (hidden_size + block_k - 1) // block_k,
            )
            w2_scale_shape = (
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
            )
            w13_scale_name = "w13_weight_scale_inv"
            w2_scale_name = "w2_weight_scale_inv"
        else:
            # Two scales (w1 and w3) per expert for w13; they are combined
            # into one after loading. One scale per expert for w2.
            w13_scale_shape = (num_experts, 2)
            w2_scale_shape = (num_experts, )
            w13_scale_name = "w13_weight_scale"
            w2_scale_name = "w2_weight_scale"

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(*w13_scale_shape, dtype=torch.float32),
            requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(*w2_scale_shape, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter(w13_scale_name, w13_weight_scale)
        layer.register_parameter(w2_scale_name, w2_weight_scale)
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"

        # Tag scale loaders with the quantization granularity
        # (per tensor vs per block).
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK
                      if self.block_quant else
                      FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs["quant_method"] = scale_kind.value

        # Only fp8 checkpoints ship weight scales. For fp16/bf16 checkpoints
        # the weights are quantized in process_weights_after_loading().
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
621
622

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert freshly loaded MoE weights into kernel-ready form.

        The layer's weight / scale parameters are replaced depending on how
        the checkpoint was stored:

        * Block-quantized fp8 (``self.block_quant``): optionally normalize to
          e4m3fnuz, or swap the w13 halves for FlashInfer; re-wrap everything
          as plain ``Parameter``s; optionally shuffle for ROCm AITER and
          TMA-align the scales for DeepGEMM.
        * Not fp8-serialized (e.g. an fp16 checkpoint): quantize each expert's
          w13 and w2 to fp8, producing one w13 scale per expert.
        * Per-tensor fp8: collapse static input scales to their max,
          optionally normalize to e4m3fnuz, requantize both w13 shards of
          every expert to one shared scale, then optionally shuffle for AITER
          and swap / rotate for FlashInfer.

        Independently of the branch taken, Marlin repacking and DeepGEMM
        UE8M0 requantization run afterwards when enabled.

        Args:
            layer: The fused-MoE module whose ``w13_*`` / ``w2_*`` parameters
                are rewritten.

        Raises:
            ValueError: If the activation scheme is static but no input
                scales were loaded.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                # Unwrap to raw tensors for all four, like the branches above.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(
                            layer.w13_weight_scale_inv)
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(
                            layer.w2_weight_scale_inv)

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(layer.local_num_experts,
                           dtype=torch.float32,
                           device=w13_weight.device),
                requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]))
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]))
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # w13 holds two stacked shards of `shard_size` rows each;
                # requantize both with the expert's shared (max) scale.
                for shard_id in range(2):
                    shard = slice(start, start + shard_size)
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][shard, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][shard, :], _ = (
                        ops.scaled_fp8_quant(dq_weight,
                                             max_w13_scales[expert_id]))
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if (self.flashinfer_moe_backend ==
                        FlashinferMoeBackend.TENSORRT_LLM):
                    # Pass the layer's current w2 tensor: a local `w2_weight`
                    # is only bound on the fnuz path above (UnboundLocalError
                    # otherwise) and would be stale after an AITER shuffle.
                    # NOTE(review): rotation is presumably in place, since
                    # only w13 is written back below — confirm.
                    rotate_flashinfer_fp8_moe_weights(
                        w13_weight, layer.w2_weight.data)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv)
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv)

    def maybe_make_prepare_finalize(
            self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Choose the prepare/finalize stage for a modular MoE kernel.

        Returns:
            ``None`` when ROCm AITER, Marlin, or the FlashInfer TRT-LLM
            backend is in use. The FlashInfer CUTLASS prepare/finalize object
            when that backend is selected. Otherwise, whatever the base
            class provides.
        """
        # Backends that run without a modular prepare/finalize stage.
        if self.rocm_aiter_moe_enabled or self.use_marlin:
            return None
        backend = self.flashinfer_moe_backend
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if backend == FlashinferMoeBackend.CUTLASS:
            cutlass_pf = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe)
            logger.debug_once("%s", cutlass_pf.__class__.__name__)
            return cutlass_pf

        return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Build the experts (GEMM) stage to pair with ``prepare_finalize``.

        Args:
            prepare_finalize: The already-selected prepare/finalize stage.
                Its activation format decides between batched and
                non-batched experts.
            layer: Unused here; accepted for interface compatibility.

        Returns:
            ``BatchedTritonOrDeepGemmExperts`` for the batched activation
            format. FlashInfer CUTLASS experts when that backend is selected.
            Otherwise, ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not (self.use_marlin or self.rocm_aiter_moe_enabled), (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        assert self.moe_quant_config is not None

        is_batched = (prepare_finalize.activation_format ==
                      FusedMoEActivationFormat.BatchedExperts)
        if is_batched:
            max_tokens = prepare_finalize.max_num_tokens_per_rank()
            assert max_tokens is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_tokens, self.weight_block_size,
                False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_tokens,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            cutlass_experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", cutlass_experts.__class__.__name__)
            return cutlass_experts

        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__, self.weight_block_size, False)
        return TritonOrDeepGemmExperts(
            quant_config=self.moe_quant_config,
            allow_deep_gemm=self.allow_deep_gemm,
        )

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        """Describe this layer's fp8 w8a8 quantization for fused-MoE kernels.

        Args:
            layer: The MoE layer whose weight and input scales are reported.

        Returns:
            ``None`` when Marlin is used. Otherwise, the config built by
            ``fp8_w8a8_moe_quant_config``: block-quantized layers report
            their ``*_weight_scale_inv`` tensors, and per-tensor layers
            report their ``*_weight_scale`` tensors.
        """
        if self.use_marlin:
            return None

        if self.block_quant:
            w1_scale = layer.w13_weight_scale_inv
            w2_scale = layer.w2_weight_scale_inv
        else:
            w1_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

        return fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the fp8 fused-MoE forward pass on ``x``.

        The FlashInfer TRT-LLM backend (when no modular kernel is installed
        in ``self.fused_experts``) consumes ``router_logits`` directly and
        returns early. Every other path first selects experts with
        ``FusedMoE.select_experts``. It then dispatches, in order, to
        ROCm AITER, Marlin, the installed modular kernel, FlashInfer CUTLASS,
        or the default ``fused_experts``.

        Args:
            layer: The MoE layer holding the processed fp8 weights/scales.
            x: Input hidden states.
            router_logits: Router scores used for expert selection.
            top_k: Number of experts per token.
            renormalize: Forwarded to expert selection; the FlashInfer paths
                assert a specific value.
            enable_eplb: When True, ``expert_load_view``,
                ``logical_to_physical_map`` and ``logical_replica_count``
                must be provided (asserted) and are forwarded to expert
                selection.
            Remaining arguments are forwarded to routing and/or the chosen
            kernel unchanged.

        Returns:
            Whatever the selected kernel returns.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and self.fused_experts is None):
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)

                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert (not renormalize
                        and custom_routing_function is not None)
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        #
        # Note: the order of checks is important since self.fused_experts
        # can override fused_experts or cutlass but not rocm or marlin.
        #
        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            assert self.fused_experts is None
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert self.fused_experts is None
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace)
        elif self.fused_experts:
            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Per-tensor only. `block_quant` is a truthy flag, so it must be
            # tested with `not`; `is None` rejected the valid False case.
            assert not self.block_quant
            assert (not renormalize and custom_routing_function is not None)
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")

            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm))


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1091
1092
1093
    """

    def __init__(self, quant_config: Fp8Config):
1094
        super().__init__(quant_config)