fp8.py 56.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.distributed import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.batch_invariant import (
19
    vllm_is_batch_invariant,
20
)
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe import (
22
23
24
25
26
27
28
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
    FusedMoEQuantConfig,
31
    RoutingMethodType,
32
33
    fp8_w8a8_moe_quant_config,
)
34
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
35
36
37
38
39
40
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
41
from vllm.model_executor.layers.quantization import QuantizationMethods
42
from vllm.model_executor.layers.quantization.base_config import (
43
44
45
    QuantizationConfig,
    QuantizeMethodBase,
)
46
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
47
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
48
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
49
50
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
51
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
52
53
54
55
56
57
58
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
64
65
66
67
68
69
70
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace,
    validate_fp8_block_shape,
)
71
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
72
73
74
75
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
76
from vllm.model_executor.layers.quantization.utils.quant_utils import (
77
78
79
    GroupShape,
    is_layer_skipped,
)
80
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
81
82
83
84
85
86
87
88
89
90
91
92
93
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
94
from vllm.model_executor.utils import set_weight_attrs
95
from vllm.platforms import current_platform
96
from vllm.scalar_type import scalar_types
97
from vllm.utils.deep_gemm import (
98
    fp8_gemm_nt,
99
100
101
    get_col_major_tma_aligned_tensor,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
102
    should_use_deepgemm_for_fp8_linear,
103
)
104
from vllm.utils.flashinfer import has_flashinfer_moe
105
from vllm.utils.import_utils import has_deep_gemm
106

107
108
109
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

110
111
112
113
logger = init_logger(__name__)

# Activation quantization schemes accepted by Fp8Config. With "static",
# Fp8LinearMethod registers a loadable `input_scale` parameter per layer;
# with "dynamic", `input_scale` is registered as None.
ACTIVATION_SCHEMES = ["static", "dynamic"]

114

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class Fp8MoeBackend(Enum):
    """Kernel backends that can execute an FP8 fused-MoE layer.

    ``get_fp8_moe_backend`` picks one of these; it never returns ``NONE``.
    """

    # Placeholder / "no backend" value.
    NONE = 0
    # FlashInfer TensorRT-LLM (latency) kernels.
    FLASHINFER_TRTLLM = 1
    # FlashInfer CUTLASS (throughput) kernels; no block quantization.
    FLASHINFER_CUTLASS = 2
    # DeepGEMM grouped GEMM for block-quantized weights.
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM for block-quantized weights.
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Marlin weight-only FP8 path.
    MARLIN = 5
    # Triton fallback used when nothing else is selected.
    TRITON = 6


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """
    Choose the primary kernel backend for an FP8 fused-MoE layer.

    Candidates are checked in priority order and the first match wins:

    1. FlashInfer: CUDA SM100 with ``VLLM_USE_FLASHINFER_MOE_FP8`` set and
       ``has_flashinfer_moe()`` true. TRTLLM or CUTLASS is chosen from
       ``get_flashinfer_moe_backend()``.
    2. Marlin: devices without SM89 capability, or when
       ``VLLM_TEST_FORCE_FP8_MARLIN`` is set. Never chosen on ROCm.
    3. DeepGEMM: ``VLLM_USE_DEEP_GEMM`` set, block-quantized weights, and
       DeepGEMM installed and supported.
    4. CUTLASS block-scaled grouped GEMM: CUDA SM100 with block-quantized
       weights.
    5. Triton otherwise.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: True when the layer's weights use block-wise scales.

    Returns:
        The selected ``Fp8MoeBackend`` member.

    Raises:
        ValueError: if FlashInfer resolves to its CUTLASS (throughput)
            backend while ``block_quant`` is True.
    """
    # --- 1. FlashInfer, opt-in on SM100 -----------------------------------
    flashinfer_enabled = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_enabled:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # Any other FlashInfer choice maps to the CUTLASS throughput kernel,
        # which rejects block-quantized weights.
        if block_quant:
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # --- 2. Marlin weight-only path (no native FP8, or forced) ------------
    marlin_wanted = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    running_on_rocm = current_platform.is_rocm()
    if marlin_wanted and not running_on_rocm:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # --- 3. DeepGEMM for block-quantized weights, when requested -----------
    if envs.VLLM_USE_DEEP_GEMM and block_quant:
        if has_deep_gemm():
            if is_deep_gemm_supported():
                logger.info_once("Using DeepGEMM backend for FP8 MoE")
                return Fp8MoeBackend.DEEPGEMM
        else:
            logger.warning_once("DeepGEMM backend requested but not available.")

    # --- 4. CUTLASS block-scaled grouped GEMM on SM100 ---------------------
    sm100_block_quant = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and block_quant
    )
    if sm100_block_quant:
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # --- 5. Fallback -------------------------------------------------------
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


185
class Fp8Config(QuantizationConfig):
    """FP8 quantization config.

    Records whether checkpoint weights are already stored in FP8, how
    activations are scaled ("static" or "dynamic"), which layers stay
    unquantized, and an optional block size for block-wise weight scales.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: Whether checkpoint weights are
                already stored in FP8.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer prefixes that should not be quantized;
                a falsy value is stored as an empty list.
            weight_block_size: Optional two-element block size enabling
                block-wise weight scales.

        Raises:
            ValueError: For an unknown activation scheme, or when
                ``weight_block_size`` is combined with a non-FP8 checkpoint,
                does not have exactly 2 entries, or is combined with a
                non-dynamic activation scheme.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            Fp8Config._check_weight_block_size(
                weight_block_size,
                is_checkpoint_fp8_serialized,
                activation_scheme,
            )
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_weight_block_size(
        weight_block_size: list[int],
        is_checkpoint_fp8_serialized: bool,
        activation_scheme: str,
    ) -> None:
        """Raise ValueError if block-wise quantization cannot be used here."""
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions"
            )
        if activation_scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{activation_scheme} activation scheme."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Identifier of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes this method accepts (BF16 and FP16)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are looked up for this method."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``ignored_layers`` entries with the given HF->vLLM mapper."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict.

        The checkpoint is treated as FP8-serialized when ``quant_method``
        contains "fp8". Ignored layers come from ``ignored_layers``, falling
        back to ``modules_to_not_convert`` when that is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        skipped = skipped or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU quantize method for ``layer``, or None.

        Linear layers listed in ``ignored_layers`` stay unquantized. Other
        linear and MoE layers use the IPEX XPU FP8 methods, and attention
        layers get ``Fp8KVCacheMethod``.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a new Fp8Config built from these settings.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            skip = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            return UnquantizedLinearMethod() if skip else XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            # NOTE(review): unlike get_quant_method, ignored_layers is not
            # consulted for MoE layers on XPU.
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method to use for ``layer``, or None.

        XPU is delegated to ``get_xpu_quant_method``. Linear and MoE layers
        matched by ``ignored_layers`` get the unquantized method. Otherwise
        they get the FP8 linear/MoE method. Attention gets
        ``Fp8KVCacheMethod``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, (LinearBase, FusedMoE)):
            skip = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod() if skip else Fp8LinearMethod(self)
            if skip:
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a compressed-tensors style q/k/v or prob output-scale name to the
        vLLM attention scale name.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None if
            ``name`` is not one of the recognised scale names
        """
        # (projection marker, suffix to replace, replacement), checked in order.
        proj_renames = (
            (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
            (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
            (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
        )
        if name.endswith(".output_scale"):
            for marker, old, new in proj_renames:
                if marker in name:
                    return name.replace(old, new)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

337
338
339

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Weight layouts after ``process_weights_after_loading``:
    block-quantized weights keep the ``[N, K]`` (output, input) layout with a
    2-D block-scale tensor. Per-tensor weights are transposed to ``[K, N]``.

    Limitations:
    1. Non-block-quantized weights use per-tensor scales (the logical shards
       of a fused module are requantized as a single weight) due to
       torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode has its own forward path (see apply()).
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # Activations are quantized per token in groups along K, so the
            # group size must be the weight's K block size. weight_block_size
            # is ordered [block_n, block_k].
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[1])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register weight and scale parameters on ``layer``.

        The weight is FP8 for serialized checkpoints and ``params_dtype``
        otherwise. Weight and input scales are only registered for serialized
        FP8 checkpoints. Block-quantized scales use the name
        ``weight_scale_inv``.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the forward kernels use.

        - Block quant: weights keep ``[N, K]``. ``weight_scale_inv`` becomes
          ``weight_scale``.
        - Non-serialized checkpoint: weights are quantized to FP8 here and
          transposed to ``[K, N]``.
        - Serialized per-tensor: shards are requantized to a single scale
          (unless Marlin), then transposed to ``[K, N]``.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (plus optional ``bias``).

        Dispatch order: batch-invariant path (when enabled), Marlin,
        block-FP8 kernel, then the per-tensor FP8 kernel.
        """
        # If batch invariant mode is enabled, prefer the DeepGEMM FP8 path and
        # fall back to a dequantize + F.linear path when it is not usable.
        if vllm_is_batch_invariant():
            return self._apply_batch_invariant(layer, x, bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

    def _apply_batch_invariant(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Batch-invariant forward: DeepGEMM if usable, else dequant + F.linear."""
        # Call is_deep_gemm_supported() ahead of time for torch.compile
        # dynamo has trouble tracing through.
        # The activation dtype (not a hard-coded BF16) is passed as the output
        # dtype, so the gate sees the dtype the model actually runs in.
        if self.block_quant and should_use_deepgemm_for_fp8_linear(
            x.dtype, layer.weight, self.use_deep_gemm
        ):
            # use group quant consistent with block size across K
            assert self.act_q_group_shape is not None
            q_input, input_scale = QuantFP8(
                False,
                self.act_q_group_shape,
                column_major_scales=True,
            )(x)

            # NOTE(review): DeepGEMM output is kept BF16 as before; this
            # assumes the gate above only passes for BF16 outputs — confirm.
            output_2d = torch.empty(
                (q_input.shape[0], layer.weight.shape[0]),
                dtype=torch.bfloat16,
                device=q_input.device,
            )
            fp8_gemm_nt(
                (q_input, input_scale),
                (layer.weight, layer.weight_scale),
                output_2d,
            )
            if bias is not None:
                output_2d = output_2d + bias
            return output_2d

        # Dequantize in the activation dtype so F.linear sees matching dtypes
        # (a hard-coded BF16 weight fails for FP16 activations).
        weight_deq = self._dequantize_weight(layer, x.dtype)
        if self.block_quant:
            # Block-quantized weight is already [N, K], as F.linear expects.
            return torch.nn.functional.linear(x, weight_deq, bias)
        # Per-tensor weight was transposed to [K, N]; transpose back.
        return torch.nn.functional.linear(x, weight_deq.t(), bias)

    def _dequantize_weight(
        self, layer: torch.nn.Module, dtype: torch.dtype
    ) -> torch.Tensor:
        """Dequantize ``layer.weight`` into ``dtype``, keeping its stored layout.

        Returns ``[N, K]`` for block quantization and ``[K, N]`` for
        per-tensor quantization.

        Raises:
            RuntimeError: if a block-quantized layer's scale is not 2-D.
        """
        weight_fp8 = layer.weight.to(dtype)
        weight_scale = layer.weight_scale.to(dtype)

        if not self.block_quant:
            # Per-tensor quantization: weight IS transposed to [K, N].
            # Scale should be scalar / [1], or per-output-channel [N].
            if weight_scale.numel() == 1:
                return weight_fp8 * weight_scale
            if (
                weight_scale.dim() == 1
                and weight_scale.shape[0] == weight_fp8.shape[1]
            ):
                # Per-output-channel: N is the LAST dim of [K, N], so the
                # scale broadcasts across the K rows (not along dim 0).
                return weight_fp8 * weight_scale.unsqueeze(0)
            # Fallback: rely on default broadcasting.
            return weight_fp8 * weight_scale

        # Block-wise quantization:
        # - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
        # - Scale layout may be [num_blocks_n, num_blocks_k] or transposed
        assert self.weight_block_size is not None
        block_n, block_k = self.weight_block_size  # Note: order is [N, K]

        N, K = weight_fp8.shape

        # determine expected number of blocks along N and K
        num_blocks_n = (N + block_n - 1) // block_n
        num_blocks_k = (K + block_k - 1) // block_k

        if weight_scale.dim() != 2:
            raise RuntimeError(
                f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
            )

        scale_rows, scale_cols = weight_scale.shape
        if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
            if num_blocks_n == num_blocks_k:
                # Ambiguous square case: skip the transpose. Warn only once,
                # since this runs on every forward call.
                logger.warning_once(
                    "Batch-invariant FP8: square block-scale %dx%d; "
                    "skipping transpose to avoid misorientation.",
                    scale_rows,
                    scale_cols,
                )
            else:
                # clear KN -> transpose to NK
                weight_scale = weight_scale.t()

        # Expand scale to match weight dimensions: [N, K]
        scale_expanded = weight_scale.repeat_interleave(
            block_n, dim=0
        ).repeat_interleave(block_k, dim=1)
        # Trim to exact weight size (in case of padding)
        scale_expanded = scale_expanded[:N, :K]
        return weight_fp8 * scale_expanded
693
694


695
696
697
698
699
700
701
702
703
704
705
706
707
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer module this method is attached to; its
            ``moe_config`` is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(self.block_quant)

        # Derive per-backend flags once so the hot path only checks booleans.
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (still empty) expert weights and scales on ``layer``.

        Weights:
            * ``w13_weight``: ``(num_experts, 2 * intermediate_size_per_partition,
              hidden_size)`` -- the gate (w1) and up (w3) projections stacked.
            * ``w2_weight``: ``(num_experts, hidden_size,
              intermediate_size_per_partition)``.

        Scales are per-expert (two for w13, one for w2) for per-tensor
        quantization, or per ``weight_block_size`` tile (registered with an
        ``_inv`` suffix) for block quantization. Static activation schemes
        additionally register per-expert input scales.

        Raises:
            ValueError: if block quantization tiles do not divide the
                partitioned intermediate size, or if a static activation
                scheme is requested for a checkpoint not serialized in FP8.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile; ceil-divide so partial
            # tiles at the edges still get a scale.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        # Placeholder; the real value is queried in
        # process_weights_after_loading().
        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded tensors into the layout the selected kernel expects.

        Exactly one of three paths runs:

        * Block-quantized FP8: e4m3fn -> e4m3fnuz normalization on fnuz
          platforms, or a w1/w3 half swap for FlashInfer; optional ROCm AITER
          shuffle; DeepGEMM scale transpose/alignment.
        * Unquantized (FP16/BF16) checkpoint: quantize every expert to FP8
          with a single per-expert scale for w13 and for w2.
        * Per-tensor FP8 checkpoint: collapse activation scales to a single
          value, requantize the w1/w3 shards of w13 to a shared per-expert
          scale, then apply platform/backend-specific layout changes.

        Afterwards, Marlin repacking runs when ``use_marlin`` is set, and
        UE8M0 requantization runs for block-quantized weights when DeepGEMM
        E8M0 is in use.
        """
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                # Unwrap to plain tensors (as for w13) so the Parameters
                # created below are plain Parameters, not subclasses.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w13_weight_scale_inv
                    )
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w2_weight_scale_inv
                    )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            # Quantize each expert independently -> one scale per expert.
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # Shard 0 is w1 (gate), shard 1 is w3 (up); each has its own
                # loaded scale that is replaced by the per-expert max.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Read w2 from the layer: a local `w2_weight` is only
                    # bound on the fnuz path above, so using it here raised
                    # UnboundLocalError on every other platform. The return
                    # value is unused, so the tensors passed in are what
                    # matter; `layer.w2_weight.data` shares storage with the
                    # registered parameter.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight.data)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )

    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage for modular (all2all) kernels.

        Returns ``None`` for backends that run as a single monolithic kernel
        (ROCm AITER, Marlin, FlashInfer TRT-LLM). FlashInfer CUTLASS gets its
        own implementation; everything else defers to the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching ``prepare_finalize``.

        Batched activation formats get ``BatchedTritonOrDeepGemmExperts``;
        FlashInfer CUTLASS gets its dedicated experts; otherwise
        ``TritonOrDeepGemmExperts`` is used.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config from the layer's scales.

        Returns ``None`` for Marlin, which receives its scales directly in
        ``apply``. Block-quantized layers use the ``*_weight_scale_inv``
        tensors; per-tensor layers use ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass with the selected backend.

        FlashInfer TRT-LLM consumes the raw router logits and returns
        immediately. All other backends route with ``FusedMoE.select_experts``
        first and then dispatch to ROCm AITER, Marlin, FlashInfer CUTLASS, or
        the default ``fused_experts`` path.

        Returns:
            The expert output tensor, or ``(result, zero_expert_result)`` when
            the layer defines zero experts (``zero_expert_num != 0`` and a
            ``zero_expert_type``).
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert not renormalize and custom_routing_function is not None
                # Return immediately. Falling through would discard this
                # result and recompute the MoE below with kernels that do not
                # expect the FlashInfer-swapped w13 layout.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )

        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert not renormalize and custom_routing_function is not None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )

            result = flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )
        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1383
1384


1385
1386
1387
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1388
1389
1390
    """

    def __init__(self, quant_config: Fp8Config):
1391
        super().__init__(quant_config)