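# NOTE(review): the lines below are residue from a code-hosting blame page and
# are not part of the module:
#   "vscode:/vscode.git/clone" did not exist on "8bf99b0b87ea99195d651710c1d71fb2e97962f7"
#   fp8.py 53.9 KB
#   Newer Older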
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from functools import partial
7
from typing import TYPE_CHECKING, Any, Optional
8
9
10
11
12

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

13
import vllm.envs as envs
14
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
15
from vllm import _custom_ops as ops
16
from vllm._aiter_ops import rocm_aiter_ops
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.batch_invariant import (
20
    vllm_is_batch_invariant,
21
)
# NOTE(review): blame-page residue (author: bnellnm) — not part of the module.
22
from vllm.model_executor.layers.fused_moe import (
23
24
25
26
27
28
29
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
30
from vllm.model_executor.layers.fused_moe.config import (
31
    FusedMoEQuantConfig,
32
    RoutingMethodType,
33
34
    fp8_w8a8_moe_quant_config,
)
35
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
36
37
38
39
40
41
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
from vllm.model_executor.layers.quantization.base_config import (
44
45
46
    QuantizationConfig,
    QuantizeMethodBase,
)
47
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
48
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
49
50
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
51
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
52
53
54
55
56
57
58
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
64
    deepgemm_post_process_fp8_weight_block,
65
66
67
68
69
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
70
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
71
72
73
74
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
75
from vllm.model_executor.layers.quantization.utils.quant_utils import (
76
77
78
    GroupShape,
    is_layer_skipped,
)
79
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
80
81
82
83
84
85
86
87
88
89
90
91
92
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
93
from vllm.model_executor.utils import set_weight_attrs
94
from vllm.platforms import current_platform
95
from vllm.scalar_type import scalar_types
96
97
98
99
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
100
from vllm.utils.flashinfer import has_flashinfer_moe
101
from vllm.utils.import_utils import has_deep_gemm
102

103
104
105
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

106
107
108
109
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

110

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class Fp8MoeBackend(Enum):
    """Kernel families that can run an FP8 fused-MoE layer.

    ``get_fp8_moe_backend`` returns one of these. The integer values are
    explicit and must stay stable.
    """

    NONE = 0  # never returned by get_fp8_moe_backend in this file
    FLASHINFER_TRTLLM = 1  # FlashInfer TensorRT-LLM MoE kernels
    FLASHINFER_CUTLASS = 2  # FlashInfer CUTLASS MoE kernels
    DEEPGEMM = 3  # DeepGEMM (block-quantized weights only)
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4  # CUTLASS block-scaled grouped GEMM
    MARLIN = 5  # Marlin weight-only FP8 path
    TRITON = 6  # default Triton fused-MoE path


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """
    Select the primary FP8 MoE backend
    Note: Shape-specific fallbacks may still occur at runtime.

    Candidates are checked in this order and the first match wins:
      1. FlashInfer (TRTLLM or CUTLASS). Requires CUDA with device capability
         90 or 100, ``VLLM_USE_FLASHINFER_MOE_FP8`` set, and FlashInfer MoE
         available.
      2. Marlin. Used on devices without capability 89, or when
         ``VLLM_TEST_FORCE_FP8_MARLIN`` is set. Never used on ROCm.
      3. DeepGEMM. Used for block-quantized weights when both DeepGEMM env
         flags are set and DeepGEMM is installed and supported.
      4. CUTLASS block-scaled grouped GEMM. Used on CUDA capability 100 with
         block-quantized weights.
      5. Triton otherwise.

    Args:
        block_quant: Whether the MoE weights are block-quantized.

    Returns:
        The selected ``Fp8MoeBackend``.

    Raises:
        ValueError: If FlashInfer selects its CUTLASS (throughput) backend on
            capability 100 while ``block_quant`` is set.
    """
    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        else:
            if block_quant and current_platform.is_device_capability(100):
                raise ValueError(
                    "FlashInfer FP8 MoE throughput backend does not "
                    "support block quantization. Please use "
                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
                    "instead."
                )
            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
            return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # deepGEMM on supported platforms with block-quantized weights
    if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
        if not has_deep_gemm():
            logger.warning_once("DeepGEMM backend requested but not available.")
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
    if (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and block_quant
    ):
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


184
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds the FP8 quantization settings (serialization, activation scheme,
    ignored layers, optional weight block size). It also hands out the
    per-layer quantize methods for linear, fused-MoE and attention layers.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint already stores
                FP8 weights.
            activation_scheme: Must be one of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer names/prefixes to keep unquantized. ``None``
                is stored as an empty list.
            weight_block_size: Optional 2-element block shape enabling
                block-wise weight quantization.

        Raises:
            ValueError: If the activation scheme is unknown. Also raised when
                ``weight_block_size`` is given together with a non-FP8
                checkpoint, with a length other than 2, or with a non-dynamic
                activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization-config dict.

        ``quant_method`` and ``activation_scheme`` are required keys. Ignored
        layers come from ``ignored_layers``. If that is missing or empty, they
        fall back to ``modules_to_not_convert``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """XPU counterpart of ``get_quant_method`` using the IPEX methods.

        Layers matched by ``ignored_layers`` stay unquantized for both linear
        and fused-MoE layers, the same as the non-XPU path.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            # Honor ignored_layers for MoE too. The original XPU path always
            # quantized MoE layers, unlike get_quant_method.
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or ``None`` if unhandled.

        XPU is delegated to ``get_xpu_quant_method``. Linear and fused-MoE
        layers matched by ``ignored_layers`` get the unquantized method.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

336
337
338

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode takes its own path in apply(), never Marlin.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        # NOTE: "enaled" (sic) is the name of the external rocm_aiter_ops API.
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        The weight is allocated as ``(output_size_per_partition,
        input_size_per_partition)``. For FP8-serialized checkpoints it is an
        FP8 parameter. Otherwise it uses ``params_dtype`` and is quantized
        later in ``process_weights_after_loading``.

        Scale parameters are registered only for FP8-serialized checkpoints:
          * ``weight_scale`` (per-tensor) or ``weight_scale_inv`` (block);
          * ``input_scale`` for static activation quantization, else ``None``.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the loaded weights into the form ``apply`` expects.

        * Block quant: process ``weight_scale_inv`` into ``weight_scale``. The
          weight is not transposed (``size_k_first=False`` for Marlin).
        * Non-serialized checkpoint: quantize the weight to FP8 here
          (``scale=None``), then transpose it.
        * Serialized per-tensor: unless Marlin is used, requantize the logical
          shards via ``process_fp8_weight_tensor_strategy`` (reducing a static
          ``input_scale`` to its max), then transpose.

        On the non-block paths, the weight created as ``(out, in)`` ends up
        stored as ``(in, out)``. The Marlin path repacks the layer and drops
        ``input_scale``.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def _apply_dequantized(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Dequantize the non-block FP8 weight and run a plain ``F.linear``.

        Used by batch-invariant mode. ``layer.weight`` is stored as ``[K, N]``
        (input, output) after ``process_weights_after_loading``. A scale with
        one entry per output channel must therefore broadcast along the last
        dim.

        The dequant runs in ``x.dtype``. A hard-coded bfloat16 breaks
        ``F.linear`` for float16 activations (mismatched operand dtypes).
        """
        weight = layer.weight.to(x.dtype)
        weight_scale = layer.weight_scale.to(x.dtype)
        num_out = weight.shape[1]

        if weight_scale.numel() == 1 or weight_scale.numel() == num_out:
            # Per-tensor scalar, or one scale per output column of [K, N].
            weight_dq = weight * weight_scale.reshape(-1)
        elif weight_scale.dim() == 1 and weight_scale.shape[0] == weight.shape[0]:
            # NOTE(review): kept from the original fallback — scales one
            # value per input row; unclear which layout produces this.
            weight_dq = weight * weight_scale.unsqueeze(1)
        else:
            # Fallback: rely on plain broadcasting.
            weight_dq = weight * weight_scale
        # F.linear expects (out, in), so transpose back from [K, N].
        return torch.nn.functional.linear(x, weight_dq.t(), bias)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (plus optional ``bias``).

        Dispatch order: batch-invariant mode, then Marlin, then the block-FP8
        op, then the per-tensor/per-token ``Fp8LinearOp``.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # for block quant; otherwise dequantize to the activation dtype.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # per-tensor/channel: dequant and run GEMM
            return self._apply_dequantized(layer, x, bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
619
620


621
622
623
624
625
626
627
628
629
630
631
632
633
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. Those weights are quantized to FP8, and their weight
    scaling factors computed, in process_weights_after_loading() after the
    model weights are loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(self.block_quant)

        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Pre-bind the static arguments of the CUTLASS FlashInfer kernel;
            # apply() supplies the per-call tensors.
            self.flashinfer_moe_fn = partial(
                flashinfer_cutlass_moe_fp8,
                moe=self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights, weight scales and input scales on `layer`.

        Registers ``w13_weight`` with shape
        (num_experts, 2 * intermediate_size_per_partition, hidden_size) and
        ``w2_weight`` with shape
        (num_experts, hidden_size, intermediate_size_per_partition). When the
        checkpoint is FP8-serialized they are allocated as float8_e4m3fn,
        otherwise in `params_dtype`.

        Weight scales are registered as:
          * block quantization: ``w13_weight_scale_inv`` / ``w2_weight_scale_inv``
            with one float32 entry per (block_n, block_k) tile;
          * otherwise: ``w13_weight_scale`` of shape (num_experts, 2), one scale
            each for w1 and w3, and ``w2_weight_scale`` of shape (num_experts,).

        With a static activation scheme, per-expert ``w13_input_scale`` /
        ``w2_input_scale`` parameters are registered. Otherwise both
        attributes are set to None.

        Raises:
            ValueError: if block sizes do not divide the partitioned
                intermediate size, or if a static activation scheme is used
                with a non-FP8 checkpoint.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding partial tiles up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the form the chosen kernel expects.

        Three mutually exclusive paths:
          * block quantization: optional fnuz normalization or FlashInfer
            w13->w31 swap, optional AITER shuffle, and optional DeepGEMM
            post-processing of weights and scales;
          * non-FP8 checkpoint: each expert's w13/w2 is quantized to FP8 here
            with a single scale per expert;
          * FP8 per-tensor checkpoint: static input scales are reduced to their
            max, and each expert's w1/w3 shards are requantized to a shared
            (max) scale. Optional AITER shuffle or FlashInfer preparation
            follows.
        Finally, Marlin layers are repacked and their input scales dropped.
        """
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # Shard 0 is w1 and shard 1 is w3 of the fused w13 tensor;
                # each one was quantized with its own scale.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # The rotation's effect on w13 is read back from the local
                    # tensor below, so it presumably updates its arguments in
                    # place. Pass the layer's own w2 Parameter so that update
                    # lands on the layer. The local `w2_weight` is only bound
                    # on the fnuz path and would raise UnboundLocalError here.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for modular kernels, if any.

        Returns None for AITER, Marlin and FlashInfer TRT-LLM, which do not go
        through the modular-kernel path. Returns the FlashInfer CUTLASS
        prepare/finalize for that backend. Otherwise defers to the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching `prepare_finalize`.

        Batched activation formats get BatchedDeepGemmExperts or
        BatchedTritonExperts. The FlashInfer CUTLASS backend gets its CUTLASS
        experts. Everything else gets TritonOrDeepGemmExperts.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )

        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 w8a8 quant config for the fused-MoE kernels.

        Returns None for Marlin. Block-quantized layers supply the
        ``*_weight_scale_inv`` block scales; others supply per-tensor
        ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing plus the expert computation for the selected backend.

        The FlashInfer TRT-LLM backend performs routing inside its kernel and
        returns immediately. All other backends first call
        FusedMoE.select_experts, then dispatch to AITER, Marlin, FlashInfer
        CUTLASS or the generic fused_experts kernel.

        Returns:
            The expert output tensor. When zero experts are configured
            (``zero_expert_num != 0`` and ``zero_expert_type`` set), returns a
            tuple of (output, zero_expert_result) instead.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert not renormalize and custom_routing_function is not None
                # Return directly, like the block-quant branch above.
                # Previously the result was stored and then overwritten by the
                # generic select_experts/fused_experts path below, which runs
                # on weights that process_weights_after_loading swapped and
                # rotated for FlashInfer.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )

        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            if not self.block_quant:
                assert not renormalize and custom_routing_function is not None
                assert scoring_func == "sigmoid", (
                    f"Expected 'sigmoid' scoring func but got {scoring_func}"
                )
            # Delegate to CUTLASS FlashInfer path; function already bound with
            # use_deepseek_fp8_block_scale for block-quant when applicable
            result = self.flashinfer_moe_fn(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )
        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1328
1329


1330
1331
1332
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1333
1334
1335
    """

    def __init__(self, quant_config: Fp8Config):
1336
        super().__init__(quant_config)