fp8.py 51.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
6
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
18
19
20
21
22
23
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
24
from vllm.model_executor.layers.fused_moe.config import (
25
26
27
28
29
30
31
32
33
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
34
from vllm.model_executor.layers.quantization import QuantizationMethods
35
from vllm.model_executor.layers.quantization.base_config import (
36
37
38
    QuantizationConfig,
    QuantizeMethodBase,
)
39
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
40
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
41
42
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
43
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
44
45
46
47
48
49
50
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
51
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
52
53
54
55
56
57
58
59
60
61
62
63
    W8A8BlockFp8LinearOp,
    check_aiter_fp8_linear_support,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace,
    validate_fp8_block_shape,
)
64
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
65
66
67
68
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
69
from vllm.model_executor.layers.quantization.utils.quant_utils import (
70
71
72
    GroupShape,
    is_layer_skipped,
)
73
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
74
75
76
77
78
79
80
81
82
83
84
85
86
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
87
from vllm.model_executor.utils import set_weight_attrs
88
from vllm.platforms import current_platform
89
from vllm.scalar_type import scalar_types
90
from vllm.utils import has_deep_gemm
91
92
93
94
95
from vllm.utils.deep_gemm import (
    get_col_major_tma_aligned_tensor,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
96
from vllm.utils.flashinfer import has_flashinfer_moe
97

98
99
100
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

# Activation quantization schemes accepted by Fp8Config. "static" reads input
# scales from an FP8-serialized checkpoint; "dynamic" does not load them.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class Fp8MoeBackend(Enum):
    """FP8 MoE kernel backends; chosen by ``get_fp8_moe_backend``."""

    NONE = 0  # Not returned by get_fp8_moe_backend.
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    MARLIN = 5  # Weight-only path for GPUs without native FP8.
    TRITON = 6  # Default fallback.


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """Select the primary FP8 MoE backend for the current platform.

    Backends are considered in priority order:

    1. FlashInfer (TRTLLM or CUTLASS) on SM100 CUDA devices when
       ``VLLM_USE_FLASHINFER_MOE_FP8`` is set and FlashInfer MoE is available.
    2. Marlin (weight-only) when the device lacks compute capability 8.9, or
       when ``VLLM_TEST_FORCE_FP8_MARLIN`` is set. Never chosen on ROCm.
    3. DeepGEMM for block-quantized weights when ``VLLM_USE_DEEP_GEMM`` is set
       and DeepGEMM is installed and supported.
    4. CUTLASS block-scaled grouped GEMM for block-quantized weights on SM100.
    5. Triton otherwise.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the MoE weights use block-wise quantization.

    Returns:
        The selected ``Fp8MoeBackend``.
    """
    # Prefer FlashInfer backends when available and enabled on SM100 GPUs.
    if (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # Every other FlashInfer backend maps onto the CUTLASS path.
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only path for older GPUs without native FP8 (or forced for
    # testing). Marlin is disabled on ROCm.
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM on supported platforms with block-quantized weights. When the
    # package is installed but unsupported here, fall through silently.
    if envs.VLLM_USE_DEEP_GEMM and block_quant:
        if not has_deep_gemm():
            logger.warning_once("DeepGEMM backend requested but not available.")
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights.
    if (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and block_quant
    ):
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # Default to Triton.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. If False, weights are loaded in their original dtype
            and quantized after loading.
        activation_scheme: Activation quantization scheme; one of
            ``ACTIVATION_SCHEMES`` (``"static"`` or ``"dynamic"``).
        ignored_layers: Layer prefixes that are left unquantized.
        weight_block_size: Optional 2-element block shape for block-wise
            weight quantization. Only supported together with an
            FP8-serialized checkpoint and the ``"dynamic"`` activation scheme.

    Raises:
        ValueError: If the activation scheme is unknown or the block-wise
            settings are inconsistent.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` from HF module names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        ``quant_method`` and ``activation_scheme`` are required keys.
        ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Return True if the layer at ``prefix`` is listed as unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU (IPEX) quantization method for ``layer``.

        Ignored linear and fused-MoE layers fall back to their unquantized
        methods, matching ``get_quant_method``.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive their own copy of this config's settings.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            # Honour ignored_layers for MoE layers too, as get_quant_method does.
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``, or None if unsupported.

        XPU is delegated to ``get_xpu_quant_method``. Layers matched by
        ``ignored_layers`` get the corresponding unquantized method.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Weight scales are either per-tensor (one per logical shard, merged after
    loading) or block-wise when ``weight_block_size`` is configured.
    Activation quantization groups are:
    - ``GroupShape(1, weight_block_size[0])`` for block-wise weights;
    - per-token for dynamic activations when CUTLASS FP8 is supported;
    - per-tensor otherwise.

    On GPUs lacking FP8 hardware support, or when
    ``VLLM_TEST_FORCE_FP8_MARLIN`` is set, the Marlin kernel is used for
    weight-only FP8 and activations are not quantized. The exception is ROCm,
    which never uses Marlin.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        # Choose the activation quantization granularity.
        if self.block_quant:
            assert self.weight_block_size is not None
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        elif not self.act_q_static and cutlass_fp8_supported():
            # Use per-token quantization for better perf if dynamic and cutlass
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight, weight-scale and input-scale parameters.

        Registers ``weight`` of shape
        ``(sum(output_partition_sizes), input_size_per_partition)``.
        It is FP8 for serialized checkpoints and ``params_dtype`` otherwise.

        For FP8-serialized checkpoints it also registers the weight scale:
        ``weight_scale`` (per-tensor) or ``weight_scale_inv`` (block-wise).
        It also registers ``input_scale``; that is None unless the activation
        scheme is static. Non-serialized checkpoints get their scales in
        ``process_weights_after_loading``.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights and scales into the form the kernels expect.

        Afterwards ``layer.weight``, ``layer.weight_scale`` and
        ``layer.input_scale`` are plain ``Parameter`` objects (or None).
        On the Marlin path the layer is repacked and ``input_scale`` is
        removed.
        """
        # Marlin repacking needs to know whether the weight is stored K-first
        # (transposed below for the per-tensor paths) or N-first (block path).
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatches to Marlin (weight-only), the block-wise W8A8 op, or the
        per-tensor/per-token FP8 op, depending on the configuration.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

580
581
582
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Select the FP8 MoE backend and record the flags derived from it.

        Args:
            quant_config: The FP8 quantization config.
            layer: The fused-MoE layer; its ``moe_config`` is forwarded to
                the base class.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None

        # Modular kernel; None at construction time.
        self.fused_experts: Optional[mk.FusedMoEModularKernel] = None  # type: ignore

        self.fp8_backend = get_fp8_moe_backend(self.block_quant)

        # Convenience flags mirroring the selected backend.
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 MoE weights, weight scales and input scales on ``layer``.

        Registered weights:
        - ``w13_weight``: ``(num_experts, 2 * intermediate_size_per_partition,
          hidden_size)``.
        - ``w2_weight``: ``(num_experts, hidden_size,
          intermediate_size_per_partition)``.

        Weight dtype is ``float8_e4m3fn`` for FP8-serialized checkpoints and
        ``params_dtype`` otherwise.

        Weight scales:
        - Per-tensor: ``w13_weight_scale`` with 2 entries per expert (w1 and
          w3) and ``w2_weight_scale`` with 1 entry per expert.
        - Block-wise: ``w13_weight_scale_inv`` and ``w2_weight_scale_inv``,
          one entry per (block_n, block_k) tile.

        A static activation scheme also registers per-expert
        ``w13_input_scale`` and ``w2_input_scale``. Otherwise those attributes
        are set to None.

        Raises:
            ValueError: If the block sizes do not divide
                ``intermediate_size_per_partition``. Also raised when a static
                activation scheme is used with a checkpoint that is not
                FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = self.weight_block_size
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        # Weights receive the caller's attrs before "quant_method" is added
        # below; only the scales get it.
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n, block_k) tile; sizes use ceil division.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process the MoE expert weights after checkpoint loading.

        One of three mutually exclusive paths runs:

        * ``self.block_quant``: weights and ``*_weight_scale_inv`` are
          normalized to e4m3fnuz on fnuz platforms, swapped to w31 order for
          FlashInfer, shuffled for ROCm AITER, and pre-aligned for DeepGEMM.
        * checkpoint not fp8-serialized: every expert's w13 and w2 weights
          are quantized to fp8 here, each with a single per-tensor scale.
        * per-tensor fp8 checkpoint: static input scales are collapsed to
          their max, the two w13 shard scales are unified by dequantizing and
          requantizing each shard with the per-expert max, and
          backend-specific layouts (fnuz, ROCm AITER, FlashInfer) are applied.

        Afterwards, Marlin repacking and DeepGEMM UE8M0 requantization run
        when those features are enabled.

        Args:
            layer: the fused-MoE module. Its weight/scale attributes are
                replaced with new ``Parameter`` objects (or updated in place).
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled,
            shuffle_weights,
        )

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                # Unwrap every tensor with `.data` (as the other branches do)
                # so the re-wrapping below always starts from plain tensors.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w13_weight_scale_inv
                    )
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w2_weight_scale_inv
                    )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # w13 holds two stacked shards of `shard_size` rows each; each
                # shard is requantized against the expert's max scale.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # The local `w2_weight` is only bound on the fnuz path above.
                    # Everywhere else, referencing it raised UnboundLocalError.
                    # Pass the layer's own Parameter so any in-place update
                    # (including a `.data` rebind) lands on the layer.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )
970

971
972
973
974
975
976
    def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Choose the prepare/finalize stage for the modular MoE kernel.

        Returns ``None`` when ROCm AITER, Marlin, or the FlashInfer TRT-LLM
        backend is in use. FlashInfer CUTLASS gets its dedicated
        prepare/finalize object. Every other configuration defers to the
        base-class implementation.
        """
        opts_out = (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if opts_out:
            return None

        if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize()

        cutlass_pf = build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe)
        logger.debug_once("%s", cutlass_pf.__class__.__name__)
        return cutlass_pf

bnellnm's avatar
bnellnm committed
987
988
989
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Return the experts implementation paired with ``prepare_finalize``.

        * Batched activation format -> ``BatchedTritonOrDeepGemmExperts``,
          sized by the prepare/finalize's per-rank token budget.
        * FlashInfer CUTLASS backend -> ``select_cutlass_fp8_gemm_impl``.
        * Otherwise -> ``TritonOrDeepGemmExperts``.

        Marlin and ROCm AITER are rejected by assertion.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )
        assert self.moe_quant_config is not None
        quant_cfg = self.moe_quant_config

        wants_batched = (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        )
        if wants_batched:
            token_budget = prepare_finalize.max_num_tokens_per_rank()
            assert token_budget is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                token_budget,
                self.weight_block_size,
                False,
            )
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=token_budget,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=quant_cfg,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            cutlass_impl = select_cutlass_fp8_gemm_impl(self.moe, quant_cfg)
            logger.debug_once("Using %s", cutlass_impl.__class__.__name__)
            return cutlass_impl

        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__,
            self.weight_block_size,
            False,
        )
        return TritonOrDeepGemmExperts(
            quant_config=quant_cfg,
            allow_deep_gemm=self.allow_deep_gemm,
        )

1042
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> Optional[FusedMoEQuantConfig]:
        """Describe this layer's fp8 W8A8 quantization to the MoE kernels.

        Returns ``None`` when Marlin is used. Block-quantized layers report
        their ``*_weight_scale_inv`` tensors as weight scales; per-tensor
        layers report ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        if self.block_quant:
            gate_up_scale = layer.w13_weight_scale_inv
            down_scale = layer.w2_weight_scale_inv
        else:
            gate_up_scale = layer.w13_weight_scale
            down_scale = layer.w2_weight_scale

        return fp8_w8a8_moe_quant_config(
            w1_scale=gate_up_scale,
            w2_scale=down_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

1062
1063
1064
1065
1066
1067
1068
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the fp8 fused-MoE forward pass on ``x``.

        When the FlashInfer TRT-LLM backend is selected and no modular
        ``self.fused_experts`` is set, routing and expert compute happen in a
        single FlashInfer call whose output is returned directly. Otherwise,
        experts are selected with ``FusedMoE.select_experts``, and the
        computation is dispatched, in priority order, to:
        ROCm AITER, Marlin, ``self.fused_experts``, FlashInfer CUTLASS, or the
        default ``fused_experts``.

        Returns:
            The MoE output. When the layer defines zero experts
            (``zero_expert_num != 0`` and a ``zero_expert_type``), returns the
            tuple ``(output, zero_expert_result)``.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if (
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
            and self.fused_experts is None
        ):
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )
            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                assert (
                    renormalize and use_grouped_topk and custom_routing_function is None
                )
                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert not renormalize and custom_routing_function is not None
                # Return immediately, like the block-quant branch above.
                # Previously the result was only stored, and execution fell
                # through to select_experts + the generic fused_experts path.
                # That overwrote this output with a computation on weights
                # already laid out for FlashInfer.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )

        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
        )

        #
        # Note: the order of checks is important since self.fused_experts
        # can override fused_experts or cutlass but not rocm or marlin.
        #
        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            assert self.fused_experts is None
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            assert self.fused_experts is None
            result = torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif self.fused_experts:
            result = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert not renormalize and custom_routing_function is not None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )

            result = flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )

        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1274
1275


1276
1277
1278
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1279
1280
1281
    """

    def __init__(self, quant_config: Fp8Config):
1282
        super().__init__(quant_config)