fp8.py 51.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
    FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
18
19
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
20
21
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
XuruiYang's avatar
XuruiYang committed
22
23
from vllm.model_executor.layers.fused_moe.layer import (
    UnquantizedFusedMoEMethod)
24
25
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
26
from vllm.model_executor.layers.quantization import QuantizationMethods
27
from vllm.model_executor.layers.quantization.base_config import (
28
    QuantizationConfig, QuantizeMethodBase)
29
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
30
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
31
32
33
34
35
    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
36
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
37
    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
38
    create_fp8_input_scale, create_fp8_scale_parameter,
39
    create_fp8_weight_parameter, expert_weight_is_col_major,
40
41
42
    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
    validate_fp8_block_shape)
43
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
44
45
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
46
from vllm.model_executor.layers.quantization.utils.quant_utils import (
47
    GroupShape, is_layer_skipped)
48
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
49
50
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
51
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
52
53
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
54
                                           PerTensorScaleParameter)
55
from vllm.model_executor.utils import set_weight_attrs
56
from vllm.platforms import current_platform
57
from vllm.scalar_type import scalar_types
58
from vllm.utils import has_deep_gemm
59
60
61
from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
                                  is_deep_gemm_e8m0_used,
                                  is_deep_gemm_supported)
62
from vllm.utils.flashinfer import has_flashinfer_moe
63

64
65
66
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

67
68
69
70
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

71

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class Fp8MoeBackend(Enum):
    """Identifiers for the kernel backends that can run FP8 MoE layers.

    ``get_fp8_moe_backend`` returns one of these members; the FP8 MoE method
    then derives its per-backend flags from the returned value.
    """
    # Not returned by get_fp8_moe_backend in the code visible here.
    NONE = 0
    # FlashInfer MoE kernels using the TensorRT-LLM backend (SM100 path).
    FLASHINFER_TRTLLM = 1
    # FlashInfer MoE kernels using the CUTLASS backend (SM100 path).
    FLASHINFER_CUTLASS = 2
    # DeepGEMM kernels; only selected for block-quantized weights.
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM; SM100 with block-quantized weights.
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Marlin weight-only path for GPUs without native FP8 support.
    MARLIN = 5
    # Default fallback backend.
    TRITON = 6


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """
    Choose the primary backend used to execute FP8 MoE layers.

    Candidates are tried in priority order and the first match wins:
    FlashInfer (SM100 CUDA devices, opt-in via env var), Marlin (no native
    FP8 support or forced via env var, never on ROCm), DeepGEMM
    (block-quantized weights, opt-in via env var), CUTLASS block-scaled
    grouped GEMM (SM100 with block-quantized weights), and finally Triton.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the weights use block-wise quantization.

    Returns:
        The selected ``Fp8MoeBackend`` member.
    """
    # 1) FlashInfer: preferred whenever it is enabled and usable on SM100.
    flashinfer_usable = (current_platform.is_cuda()
                         and current_platform.is_device_capability(100)
                         and envs.VLLM_USE_FLASHINFER_MOE_FP8
                         and has_flashinfer_moe())
    if flashinfer_usable:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once(
                "Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        logger.info_once(
            "Using FlashInfer FP8 MoE CUTLASS backend for SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # 2) Marlin: weight-only path for GPUs lacking native FP8 (or when
    #    forced for testing). ROCm always opts out.
    marlin_selected = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    if current_platform.is_rocm():
        marlin_selected = False
    if marlin_selected:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # 3) DeepGEMM: only considered for block-quantized weights. When it is
    #    installed but unsupported we silently fall through to the next
    #    candidate.
    deep_gemm_requested = envs.VLLM_USE_DEEP_GEMM and block_quant
    if deep_gemm_requested:
        if has_deep_gemm():
            if is_deep_gemm_supported():
                logger.info_once("Using DeepGEMM backend for FP8 MoE")
                return Fp8MoeBackend.DEEPGEMM
        else:
            logger.warning_once(
                "DeepGEMM backend requested but not available.")

    # 4) CUTLASS block-scaled grouped GEMM: SM100 + block-quantized weights.
    sm100_block_quant = (current_platform.is_cuda()
                         and current_platform.is_device_capability(100)
                         and block_quant)
    if sm100_block_quant:
        logger.info_once(
            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # 5) Nothing more specific matched: use Triton.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


131
class Fp8Config(QuantizationConfig):
    """Quantization configuration for FP8 weights/activations."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """Validate and store the FP8 settings.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint already
                stores FP8 weights.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer names that must stay unquantized; ``None``
                is stored as an empty list.
            weight_block_size: Two-element block shape for block-wise weight
                quantization, or ``None`` when block quantization is off.

        Raises:
            ValueError: On an unknown activation scheme, or when block-wise
                quantization is combined with a non-FP8 checkpoint, a block
                shape that does not have two entries, or a non-dynamic
                activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        scheme_is_known = activation_scheme in ACTIVATION_SCHEMES
        if not scheme_is_known:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers if ignored_layers else []

        # Block-wise quantization carries extra constraints; check them in
        # order so that the first violated one is reported.
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config filenames to look for (none for FP8)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` through the given weights mapper."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dictionary.

        The checkpoint is treated as FP8-serialized when ``"fp8"`` occurs in
        the ``quant_method`` entry. The ignored-layer list is read from
        ``ignored_layers`` and, when that is missing or empty, from
        ``modules_to_not_convert``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not skipped_layers:
            # Fall back to the alternative key for the same information.
            skipped_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None)
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=skipped_layers,
            weight_block_size=block_size,
        )

    def get_xpu_quant_method(self, layer: torch.nn.Module,
                             prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the XPU-specific quantize method for ``layer``.

        Linear and fused-MoE layers receive a fresh copy of this config;
        attention layers get the KV-cache method bound to ``self``. Any other
        layer type yields ``None``.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod, XPUFp8MoEMethod)

        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size)

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping)
            if skipped:
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method to use for ``layer``.

        On XPU the decision is delegated to ``get_xpu_quant_method``.
        Otherwise linear and fused-MoE layers get their FP8 method unless the
        layer is listed as skipped, attention layers get the FP8 KV-cache
        method, and every other layer type yields ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping)
            if skipped:
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)

        if isinstance(layer, FusedMoE):
            skipped = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping)
            if skipped:
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Translate a compressed-tensors style k/v cache scale parameter name
        into the parameter name expected by vLLM.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None if
            the name is not a recognized cache scale
        """
        # (marker that must occur in the name, substring to replace,
        #  replacement); checked in this order.
        projection_rules = (
            (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
            (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
            (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
        )
        if name.endswith(".output_scale"):
            for marker, old_suffix, new_suffix in projection_rules:
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # Not a cache scale parameter.
        return None

268
269
270

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Handles FP8 checkpoints that ship a static weight scale together with
    either a dynamic or a static activation scale.

    FP16/BF16 checkpoints are supported as well, with dynamic activation
    scaling; in that case the weight scale is computed once the model
    weights have been loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # GPUs without FP8 hardware support fall back to the Marlin kernel
        # for fast weight-only FP8 quantization. ROCm never uses Marlin.
        marlin_wanted = (not current_platform.has_device_capability(89)
                         or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        self.use_marlin = marlin_wanted
        if current_platform.is_rocm():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

        block_size = self.quant_config.weight_block_size
        self.weight_block_size = block_size
        self.block_quant = block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        # Activation quantization granularity.
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        elif not self.act_q_static and cutlass_fp8_supported():
            # Dynamic + cutlass: per-token quantization for better perf.
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        # Exactly one of the two linear ops is created, depending on whether
        # block quantization is in use.
        if not self.block_quant:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, scale parameters.

        For checkpoints that are not FP8-serialized only ``weight`` is
        registered (in ``params_dtype``); its scale is produced later by
        ``process_weights_after_loading``.
        """
        maybe_create_device_identity()

        total_output_size = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        # Bookkeeping read later by weight processing and the apply path.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(layer, input_size, output_size,
                                     input_size_per_partition,
                                     output_partition_sizes,
                                     self.weight_block_size)

        # WEIGHT
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        if fp8_serialized:
            weight = create_fp8_weight_parameter(total_output_size,
                                                 input_size_per_partition,
                                                 loader)
        else:
            # Non-serialized checkpoints keep the original dtype for now.
            raw_weight = torch.empty(total_output_size,
                                     input_size_per_partition,
                                     dtype=params_dtype)
            weight = ModelWeightParameter(data=raw_weight,
                                          input_dim=1,
                                          output_dim=0,
                                          weight_loader=loader)
        layer.register_parameter("weight", weight)

        # Scales are only loaded from FP8-serialized checkpoints; otherwise
        # they are computed in process_weights_after_loading.
        if not fp8_serialized:
            return

        # WEIGHT SCALE
        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(BlockQuantScaleParameter,
                                               output_partition_sizes,
                                               input_size_per_partition,
                                               self.weight_block_size,
                                               loader)
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)
        else:
            scale = create_fp8_scale_parameter(PerTensorScaleParameter,
                                               output_partition_sizes,
                                               input_size_per_partition,
                                               None, loader)
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)

        # INPUT ACTIVATION SCALE
        if not self.act_q_static:
            layer.register_parameter("input_scale", None)
        else:
            scale = create_fp8_input_scale(output_partition_sizes, loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring the loaded weights and scales into their runtime form.

        Three cases are handled: block-quantized FP8 checkpoints,
        non-FP8 checkpoints (quantized here), and per-tensor FP8
        checkpoints. Afterwards the layer is optionally prepared for Marlin
        or post-processed for block quantization.
        """
        # The block-quant path is the only one that passes
        # size_k_first=False to the Marlin preparation.
        size_k_first = not self.block_quant
        act_scale = None

        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            new_weight, new_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv)
            # Drop weight_scale_inv so it cannot be confused with the
            # weight_scale parameter set below.
            del layer.weight_scale_inv

        elif self.quant_config.is_checkpoint_fp8_serialized:
            # Per-tensor FP8 checkpoint: a fused module carries N scales for
            # its N shards.
            new_weight = layer.weight
            new_scale = layer.weight_scale

            # For w8a8, torch._scaled_mm needs per tensor, so the logical
            # shards are requantized as a single weight.
            if not self.use_marlin:
                new_weight, new_scale, act_scale = (
                    process_fp8_weight_tensor_strategy(
                        new_weight, new_scale, layer.logical_widths,
                        getattr(layer, 'input_scale', None)))
                if self.act_q_static:
                    assert act_scale is not None
                    act_scale = act_scale.max()
            new_weight = new_weight.t()

        else:
            # Checkpoint is not FP8: quantize the weights now.
            quantized, new_scale = ops.scaled_fp8_quant(layer.weight,
                                                        scale=None)
            new_weight = quantized.t()

        # Update layer with new values.
        layer.weight = Parameter(new_weight.data, requires_grad=False)
        layer.weight_scale = Parameter(new_scale.data, requires_grad=False)
        if act_scale is None:
            layer.input_scale = None
        else:
            layer.input_scale = Parameter(act_scale, requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(
                layer, self.cutlass_block_fp8_supported)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` and return the output tensor.

        Dispatches to Marlin, the block-quantized op, or the regular FP8
        linear op, in that order of precedence.
        """
        if self.use_marlin:
            # Weight-only path: no input scale is involved.
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
491
492


493
494
495
496
497
498
499
500
501
502
503
504
505
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

506
507
508
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Store the config and derive the backend flags for this layer.

        Args:
            quant_config: The FP8 quantization config.
            layer: The MoE layer this method is attached to; its
                ``moe_config`` is forwarded to the base class.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None

        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore

        selected_backend = get_fp8_moe_backend(self.block_quant)
        self.fp8_backend = selected_backend

        self.use_marlin = (selected_backend == Fp8MoeBackend.MARLIN)

        # Only the two FlashInfer FP8 backends have a FlashInfer counterpart;
        # every other backend leaves this attribute as None.
        flashinfer_counterparts = {
            Fp8MoeBackend.FLASHINFER_TRTLLM:
            FlashinferMoeBackend.TENSORRT_LLM,
            Fp8MoeBackend.FLASHINFER_CUTLASS:
            FlashinferMoeBackend.CUTLASS,
        }
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = (
            flashinfer_counterparts.get(selected_backend))

        self.allow_deep_gemm = (selected_backend == Fp8MoeBackend.DEEPGEMM)
        self.allow_cutlass_block_scaled_grouped_gemm = (
            selected_backend
            == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM)
529

530
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register expert weights, weight scales and input scales on ``layer``.

        Weights are allocated as FP8 when the checkpoint is FP8-serialized
        and in ``params_dtype`` otherwise. Weight scales are per-tensor, or
        per-block when block quantization is enabled. Input scales are only
        created for the static activation scheme.

        Raises:
            ValueError: If the per-partition intermediate size is not
                divisible by the quantization block shape, or if a static
                activation scheme is combined with a checkpoint that is not
                FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        if fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = self.weight_block_size[0]
            block_k = self.weight_block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            gate_up_misaligned = intermediate_size_per_partition % block_n != 0
            if gate_up_misaligned:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            down_misaligned = (tp_size > 1 and
                               intermediate_size_per_partition % block_k != 0)
            if down_misaligned:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        # These receive extra_weight_attrs *before* "quant_method" is added
        # to it further down.
        w13_shape = (num_experts, 2 * intermediate_size_per_partition,
                     hidden_size)
        w13_weight = torch.nn.Parameter(
            torch.empty(w13_shape, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)
        w2_weight = torch.nn.Parameter(
            torch.empty(w2_shape, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            # One scale per (block_n x block_k) tile, rounding up on
            # partially filled tiles.
            inter_n_blocks = ((intermediate_size_per_partition + block_n - 1)
                              // block_n)
            inter_k_blocks = ((intermediate_size_per_partition + block_k - 1)
                              // block_k)
            hidden_n_blocks = (hidden_size + block_n - 1) // block_n
            hidden_k_blocks = (hidden_size + block_k - 1) // block_k

            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts,
                           2 * inter_n_blocks,
                           hidden_k_blocks,
                           dtype=torch.float32),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts,
                           hidden_n_blocks,
                           inter_k_blocks,
                           dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
        else:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        if self.block_quant:
            scale_kind = FusedMoeWeightScaleSupported.BLOCK
        else:
            scale_kind = FusedMoeWeightScaleSupported.TENSOR
        extra_weight_attrs.update({"quant_method": scale_kind.value})

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if fp8_serialized:
            for weight_scale in (w13_weight_scale, w2_weight_scale):
                set_weight_attrs(weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            # Dynamic scheme: no input scales are stored on the layer.
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
657
658

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process the MoE expert weights and scales of ``layer``.

        Exactly one of three paths runs, depending on the configuration:

        * block quantization (``self.block_quant``): optionally normalize
          to the fnuz fp8 format or swap the w13 halves for flashinfer,
          re-wrap everything as plain ``Parameter`` objects, and optionally
          pre-align the scales for DeepGEMM;
        * checkpoint not serialized as fp8: quantize every expert to fp8
          with one scale per expert;
        * fp8 checkpoint with per-tensor scales: collapse the activation
          scales to a single value and requantize w13 so that each expert
          uses a single w13 weight scale.

        Afterwards the Marlin and DeepGEMM-E8M0 specific fix-ups are
        applied when those modes are active.

        Args:
            layer: The module whose ``w13_*`` / ``w2_*`` parameters are
                rewritten in place.

        Raises:
            ValueError: If the activation scheme is "static" but the layer
                has no input scales.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # The returned input scales are not used on this path
                # (the activation scheme is asserted to be dynamic above).
                w13_weight, w13_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.local_num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == \
                    FlashinferMoeBackend.TENSORRT_LLM:
                    # Take w2 from the layer explicitly: a local
                    # ``w2_weight`` is only bound on the fnuz path above,
                    # so relying on it raises UnboundLocalError elsewhere.
                    # The result is written back the same way as w13 so
                    # the layer sees the rotated tensor whether the helper
                    # rewrites storage or rebinds ``.data``.
                    w2_weight = layer.w2_weight.data
                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
                    layer.w2_weight.data = w2_weight.data
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv)
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv)

    def maybe_make_prepare_finalize(
            self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Return the prepare/finalize object for the active MoE backend.

        Returns:
            ``None`` when ROCm AITER, Marlin or the FlashInfer TensorRT-LLM
            backend is in use; the FlashInfer CUTLASS prepare/finalize
            object when that backend is selected; otherwise whatever the
            parent class provides.
        """
        # Guard clauses: these configurations get no prepare/finalize step.
        if self.rocm_aiter_moe_enabled or self.use_marlin:
            return None
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            cutlass_pf = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe)
            logger.debug_once("%s", cutlass_pf.__class__.__name__)
            return cutlass_pf

        # Everything else defers to the base implementation.
        return super().maybe_make_prepare_finalize()

bnellnm's avatar
bnellnm committed
883
884
885
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation that pairs with
        ``prepare_finalize``.

        Args:
            prepare_finalize: The prepare/finalize object; its activation
                format decides between the batched and the standard
                experts implementation.
            layer: The MoE layer (not read by this method).

        Returns:
            ``BatchedTritonOrDeepGemmExperts`` for the batched-experts
            activation format, the FlashInfer CUTLASS implementation when
            that backend is selected, otherwise
            ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        assert self.moe_quant_config is not None

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            max_num_tokens_per_rank = (
                prepare_finalize.max_num_tokens_per_rank())
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_num_tokens_per_rank,
                self.weight_block_size, False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, self.weight_block_size, False)
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        """Build the fp8 w8a8 quantization config for ``layer``.

        Args:
            layer: The MoE layer whose weight and input scales are read.

        Returns:
            ``None`` when Marlin is in use; otherwise a config built from
            the layer's scales, using the ``*_weight_scale_inv`` tensors
            for block quantization and ``*_weight_scale`` otherwise.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the fp8 MoE forward pass for ``layer`` on ``x``.

        When the FlashInfer TensorRT-LLM backend is active and no modular
        ``self.fused_experts`` kernel is installed, routing and expert
        computation are delegated to FlashInfer and its result is returned
        directly. Otherwise experts are selected with
        ``FusedMoE.select_experts`` and one of the ROCm AITER, Marlin,
        modular, FlashInfer CUTLASS or default fused-experts kernels is
        invoked.

        Returns:
            The expert output tensor, or a ``(result, zero_expert_result)``
            tuple when the layer declares a non-zero ``zero_expert_num``
            together with a ``zero_expert_type``.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and self.fused_experts is None):
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)
                e_score_correction_bias = (e_score_correction_bias.to(
                    x.dtype) if e_score_correction_bias is not None else None)
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert (not renormalize
                        and custom_routing_function is not None)
                # Return here, like the block-quant branch above. Falling
                # through would rerun routing and let the generic dispatch
                # below overwrite this result with a second expert pass.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input)

        zero_expert_num = getattr(layer, 'zero_expert_num', 0)
        zero_expert_type = getattr(layer, 'zero_expert_type', None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        #
        # Note: the order of checks is important since self.fused_experts
        # can override fused_experts or cutlass but not rocm or marlin.
        #
        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            assert self.fused_experts is None
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert self.fused_experts is None
            result = torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace)
        elif self.fused_experts is not None:
            result = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert (not renormalize and custom_routing_function is not None)
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")

            result = flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm))
        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), \
                "Shared + zero experts are mutually exclusive not yet supported"
            return result, zero_expert_result
        else:
            return result


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        # All behaviour comes from the base class; only the config is
        # forwarded.
        super().__init__(quant_config)