fp8.py 56.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.attention.layer import Attention
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.batch_invariant import (
20
    vllm_is_batch_invariant,
21
)
bnellnm's avatar
bnellnm committed
22
from vllm.model_executor.layers.fused_moe import (
23
24
25
26
27
28
29
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
30
from vllm.model_executor.layers.fused_moe.config import (
31
    FusedMoEParallelConfig,
32
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
35
    fp8_w8a8_moe_quant_config,
)
36
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
37
38
39
40
41
42
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
from vllm.model_executor.layers.quantization.base_config import (
45
46
47
    QuantizationConfig,
    QuantizeMethodBase,
)
48
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
49
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
50
51
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
52
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
53
54
55
56
57
58
59
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
60
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
61
62
63
64
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
65
    deepgemm_post_process_fp8_weight_block,
66
67
68
69
70
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
71
72
73
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
74
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
75
76
77
78
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
79
from vllm.model_executor.layers.quantization.utils.quant_utils import (
80
81
82
    GroupShape,
    is_layer_skipped,
)
83
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
84
85
86
87
88
89
90
91
92
93
94
95
96
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
97
from vllm.model_executor.utils import set_weight_attrs
98
from vllm.platforms import current_platform
99
from vllm.scalar_type import scalar_types
100
101
102
103
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
104
from vllm.utils.flashinfer import has_flashinfer_moe
105
from vllm.utils.import_utils import has_deep_gemm
106

107
108
109
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

110
111
112
113
# Activation quantization schemes accepted by Fp8Config (anything else is
# rejected in Fp8Config.__init__). With "static", Fp8LinearMethod registers
# an `input_scale` parameter for FP8-serialized checkpoints; with "dynamic"
# no input scale is registered (activation scales are presumably computed
# at runtime by the FP8 GEMM op -- not visible here).
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger.
logger = init_logger(__name__)

114

115
116
117
118
119
120
121
122
123
124
class Fp8MoeBackend(Enum):
    """Kernel backends that can run an FP8 fused-MoE layer.

    `get_fp8_moe_backend` in this module picks one of these per layer;
    it never returns `NONE`.
    """

    # Placeholder: no backend selected.
    NONE = 0
    # FlashInfer MoE using its TensorRT-LLM kernels.
    FLASHINFER_TRTLLM = 1
    # FlashInfer MoE using its CUTLASS kernels.
    FLASHINFER_CUTLASS = 2
    # DeepGEMM grouped GEMM (block-quantized weights only).
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM (block-quantized weights, SM100).
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Marlin weight-only FP8 kernels.
    MARLIN = 5
    # Triton fused-MoE kernels (the default fallback).
    TRITON = 6


125
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend:
    """Select the primary FP8 MoE backend.

    Candidates are checked in priority order and the first whose
    requirements hold is returned:

    1. Triton, unconditionally, when LoRA support is requested.
    2. FlashInfer (TRTLLM or CUTLASS, as reported by
       ``get_flashinfer_moe_backend()``) on CUDA devices with compute
       capability 9.0 or 10.0, when ``VLLM_USE_FLASHINFER_MOE_FP8`` is set
       and FlashInfer MoE is available.
    3. Marlin (weight-only) on non-ROCm platforms where
       ``current_platform.has_device_capability(89)`` is false, or when
       ``VLLM_TEST_FORCE_FP8_MARLIN`` is set.
    4. DeepGEMM for block-quantized weights, when ``VLLM_USE_DEEP_GEMM`` is
       set, the MoE DeepGEMM toggle is on and DeepGEMM is installed and
       supported.
    5. CUTLASS block-scaled grouped GEMM for block-quantized weights on
       CUDA compute capability 10.0.
    6. Triton otherwise.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the weights use block-wise quantization.
        moe_parallel_config: Parallel config; only ``tp_size`` is read.
        with_lora_support: Whether the layer must support LoRA.

    Returns:
        The selected ``Fp8MoeBackend``.

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) backend would be
            used with block-quantized weights on compute capability 10.0.
    """
    # LoRA is only wired up for the Triton path.
    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    flashinfer_eligible = (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_eligible:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM

        # Any other FlashInfer choice maps to the CUTLASS (throughput) path.
        if block_quant and current_platform.is_device_capability(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only path for older GPUs without native FP8 (never on ROCm).
    lacks_native_fp8 = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    on_rocm = current_platform.is_rocm()
    if lacks_native_fp8 and not on_rocm:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM for block-quantized weights: an explicit
    # VLLM_MOE_USE_DEEP_GEMM is respected as-is; when it is left unset the
    # toggle is forced off for TP size >= 8.
    deep_gemm_for_moe = envs.VLLM_MOE_USE_DEEP_GEMM
    user_set_deep_gemm = envs.is_set("VLLM_MOE_USE_DEEP_GEMM")
    if not user_set_deep_gemm and moe_parallel_config.tp_size >= 8:
        deep_gemm_for_moe = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    wants_deep_gemm = envs.VLLM_USE_DEEP_GEMM and deep_gemm_for_moe and block_quant
    if wants_deep_gemm:
        deep_gemm_installed = has_deep_gemm()
        if not deep_gemm_installed:
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM
        # Installed but unsupported: fall through silently.

    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights.
    cutlass_block_scaled_ok = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and block_quant
    )
    if cutlass_block_scaled_ok:
        logger.info_once(
            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE", scope="local"
        )
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


209
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Records how an FP8 checkpoint should be loaded:

    * whether the checkpoint weights are already serialized as FP8;
    * the activation scheme, one of ``ACTIVATION_SCHEMES``;
    * layers to leave unquantized (``ignored_layers``);
    * an optional 2-element ``weight_block_size`` enabling block-wise
      weight quantization (FP8-serialized, dynamic activations only).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # A missing (or empty) value is normalised to an empty list.
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._check_block_quant_args(
                is_checkpoint_fp8_serialized, activation_scheme, weight_block_size
            )
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_block_quant_args(
        is_checkpoint_fp8_serialized: bool,
        activation_scheme: str,
        weight_block_size: list[int],
    ) -> None:
        """Raise ValueError if block-wise quantization can't be configured."""
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions"
            )
        if activation_scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{activation_scheme} activation scheme."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts (a fresh list)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config filenames to look for (none for FP8)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint's quantization config dict.

        ``quant_method`` and ``activation_scheme`` are required keys. Ignored
        layers come from ``ignored_layers``, falling back to
        ``modules_to_not_convert`` when the former is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # Any quant_method containing "fp8" means the weights are already FP8.
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        ignored_layers = ignored_layers or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_ignored(self, prefix: str):
        """Whether the layer at ``prefix`` matches ``ignored_layers``."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU quantization method for ``layer``, or None.

        Linear layers honour ``ignored_layers``; fused-MoE layers always get
        the XPU FP8 MoE method; attention layers get ``Fp8KVCacheMethod``.
        """
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a freshly built copy of this config.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_ignored(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``, or None.

        XPU platforms are delegated to ``get_xpu_quant_method``. Elsewhere,
        linear and fused-MoE layers listed in ``ignored_layers`` get the
        unquantized method; others get the FP8 method with their Marlin
        input dtype set from ``prefix``.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, LinearBase):
            if self._is_layer_ignored(prefix):
                return UnquantizedLinearMethod()
            linear_method = Fp8LinearMethod(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if self._is_layer_ignored(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_method = Fp8MoEMethod(self, layer)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
        """
        # (projection marker, checkpoint suffix, vLLM suffix), checked in order.
        projection_rules = (
            (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
            (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
            (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
        )
        if name.endswith(".output_scale"):
            for marker, ckpt_suffix, vllm_suffix in projection_rules:
                if marker in name:
                    return name.replace(ckpt_suffix, vllm_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # No known scale pattern.
        return None

362
363
364

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with
    dynamic activation scaling: such weights are quantized to FP8 once they
    have been fully loaded, and the weight scaling factor is computed then.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode uses its own GEMM path in apply().
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        elif not self.act_q_static and cutlass_fp8_supported():
            # Use per-token quantization for better perf if dynamic and cutlass
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            # Fp8Config only allows block quantization with dynamic activations.
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register weight (and, when applicable, scale) parameters on ``layer``.

        For FP8-serialized checkpoints the FP8 weight is created with
        ``create_fp8_weight_parameter`` together with a ``weight_scale``
        (per-tensor) or ``weight_scale_inv`` (block-wise) parameter, plus an
        ``input_scale`` when the activation scheme is static. Otherwise the
        weight is allocated as [output, input] in ``params_dtype`` and
        ``process_weights_after_loading`` is triggered as soon as every
        element has been loaded.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # load the current weight chunk
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0
                layer._loaded_numel += loaded_weight.numel()

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout ``apply`` expects.

        * Block-quantized: ``process_fp8_weight_block_strategy`` is applied
          to ``weight``/``weight_scale_inv``; the result is stored as
          ``weight``/``weight_scale`` (not transposed) and
          ``weight_scale_inv`` is deleted.
        * Not FP8-serialized: ``layer.weight`` is quantized with
          ``ops.scaled_fp8_quant(..., scale=None)`` and stored transposed.
        * FP8 per-tensor: unless Marlin is used, the logical shards are
          requantized via ``process_fp8_weight_tensor_strategy``; the weight
          is stored transposed.

        Marlin layers are then repacked and lose ``input_scale``;
        block-quantized layers go through
        ``maybe_post_process_fp8_weight_block``. Returns immediately if the
        patched weight loader from ``create_weights`` already ran it.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_weight(
        layer: torch.nn.Module, dtype: torch.dtype
    ) -> torch.Tensor:
        """Dequantize a non-block-quantized FP8 weight to ``dtype``.

        After ``process_weights_after_loading`` the non-block weight is
        stored transposed as [K, N] (input, output), so scales belong to the
        output dimension N (dim 1). Recognized ``layer.weight_scale``
        layouts:

        * a single element: per-tensor scale;
        * N elements: one scale per output channel;
        * ``len(layer.logical_widths)`` elements: one scale per logical
          shard of a fused module (e.g. QKV), expanded over each shard's
          output width.

        Any other layout falls back to the previous best-effort broadcasting.
        """
        weight = layer.weight.to(dtype)
        scale = layer.weight_scale.to(dtype)
        if scale.numel() == 1:
            return weight * scale

        size_n = weight.shape[1]
        flat_scale = scale.reshape(-1)
        logical_widths = getattr(layer, "logical_widths", None)
        if (
            flat_scale.numel() != size_n
            and logical_widths
            and flat_scale.numel() == len(logical_widths)
            and sum(logical_widths) == size_n
        ):
            # One scale per fused shard: repeat each over its output columns.
            flat_scale = torch.repeat_interleave(
                flat_scale,
                torch.tensor(logical_widths, device=flat_scale.device),
                output_size=size_n,
            )
        if flat_scale.numel() == size_n:
            # Broadcast along the output dimension (columns of [K, N]).
            return weight * flat_scale.unsqueeze(0)

        # Unrecognized layout: keep the original best-effort behaviour.
        if scale.dim() == 1 and scale.shape[0] == weight.shape[0]:
            return weight * scale.unsqueeze(1)
        return weight * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T + bias`` with the FP8 weight held by ``layer``.

        Dispatch order: batch-invariant mode (block-quant op, or dequantize
        and plain ``F.linear``), then Marlin, then the block-quant op, then
        the per-tensor/per-token ``Fp8LinearOp``.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use a dequantized GEMM when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # per-tensor/channel: dequantize to the activation dtype (so
            # F.linear sees matching input/weight/bias dtypes) and run GEMM.
            weight_dq = self._dequantize_weight(layer, x.dtype)
            return torch.nn.functional.linear(x, weight_dq.t(), bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
676
677


678
679
680
681
682
683
684
685
686
687
688
689
690
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
        layer: The FusedMoE layer this method is attached to; its
            ``moe_config`` / ``moe_parallel_config`` drive backend selection.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Pick the kernel backend once, up front; every later decision
        # (weight layout, prepare/finalize, experts impl, apply path) keys
        # off the flags derived from it below.
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Pre-bind the static arguments so `apply` only passes per-call data.
            self.flashinfer_moe_fn = partial(
                flashinfer_cutlass_moe_fp8,
                moe=self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights, weight scales and input scales on `layer`.

        Weights are allocated as ``float8_e4m3fn`` when the checkpoint is
        FP8-serialized, otherwise in `params_dtype` (online quantization
        happens later in `process_weights_after_loading`).

        Registered parameters:
          * ``w13_weight``: (num_experts, 2 * intermediate, hidden)
          * ``w2_weight``:  (num_experts, hidden, intermediate)
          * per-tensor mode: ``w13_weight_scale`` (num_experts, 2) and
            ``w2_weight_scale`` (num_experts,)
          * block mode: ``w13_weight_scale_inv`` / ``w2_weight_scale_inv``
            with one entry per (block_n, block_k) tile
          * static activation scheme: ``w13_input_scale`` / ``w2_input_scale``
            (num_experts,); otherwise both attributes are set to None.

        Raises:
            ValueError: if block-quantized shapes are not divisible by the
                block size, or if a static activation scheme is requested for
                a checkpoint that is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # if we are doing online quantization, patch the weight
        # loaded to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded
        if not self.quant_config.is_checkpoint_fp8_serialized:
            weight_loader = extra_weight_attrs["weight_loader"]
            # create a new holder to prevent modifying behavior of any other
            # objects which might depend on the old one (a real copy, not an
            # alias, so the original mapping is left untouched)
            new_extra_weight_attrs = dict(extra_weight_attrs)

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # load the current weight chunk
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

                # add a counter to track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0
                layer._loaded_numel += loaded_weight.numel()

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call
                    # from doing anything
                    layer._already_called_process_weights_after_loading = True

                return res

            new_extra_weight_attrs["weight_loader"] = patched_weight_loader
            extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n, block_k) tile; ceil-divide so partial
            # trailing tiles still get a scale.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded expert weights into the layout the kernels need.

        Three mutually exclusive paths:
          * block-quantized: optional e4m3fnuz normalization or FlashInfer
            w13->w31 swap, optional ROCm AITER shuffle, optional DeepGEMM
            scale post-processing;
          * non-FP8 checkpoint: quantize each expert to FP8 in place
            (one scale per expert for w13 and w2);
          * FP8 per-tensor checkpoint: collapse static input scales to their
            max, optional e4m3fnuz normalization, requantize w13 so both
            shards share one scale per expert, then backend-specific layout
            tweaks (AITER shuffle, FlashInfer swap/rotation).
        Finally, Marlin repacks the layer if enabled.

        A no-op if the streaming online-quantization loader already ran it.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # Input scales are None for the dynamic scheme; ignore them.
                w13_weight, w13_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight,
                    layer.w13_weight_scale_inv,
                    layer.w13_input_scale,
                )
                w2_weight, w2_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                # Unwrap to plain tensors (like w13) before re-wrapping below.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Use the layer's current w2 tensor: the local `w2_weight`
                    # is only bound on the fnuz path above, so referencing it
                    # here raised UnboundLocalError on every other platform.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight.data)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage for all2all, or None if unused.

        ROCm AITER, Marlin and FlashInfer TRT-LLM do not use a modular
        prepare/finalize stage; FlashInfer CUTLASS builds its own; everything
        else defers to the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts kernel that pairs with `prepare_finalize`.

        Batched activation formats get a batched DeepGEMM/Triton kernel; LoRA
        forces plain Triton; FlashInfer CUTLASS gets its own GEMM impl; the
        default is TritonOrDeepGemmExperts.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )

        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 w8a8 quant config, or None when Marlin is used.

        Block-quantized layers expose their scales as ``*_weight_scale_inv``;
        per-tensor layers as ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass for `x`.

        The FlashInfer TRT-LLM backend routes internally from `router_logits`
        and returns immediately. All other backends first call
        `layer.select_experts` and then dispatch to ROCm AITER, Marlin,
        FlashInfer CUTLASS or the generic `fused_experts` kernel.

        Returns:
            The MoE output, or ``(output, zero_expert_result)`` when the layer
            has zero experts configured.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                # Return directly: falling through would recompute the MoE
                # with the generic kernel on TRT-LLM-swapped/rotated weights
                # and overwrite this result.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        select_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert layer.activation == "silu", (
                f"{layer.activation} not supported for Marlin MoE."
            )
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
                workspace=layer.workspace,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )
            if not self.block_quant:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                assert layer.scoring_func == "sigmoid", (
                    f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
                )
            # Delegate to CUTLASS FlashInfer path; function already bound with
            # use_deepseek_fp8_block_scale for block-quant when applicable
            result = self.flashinfer_moe_fn(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )

        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1395
1396


1397
1398
1399
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1400
1401
1402
    """

    def __init__(self, quant_config: Fp8Config):
1403
        super().__init__(quant_config)