fp8.py 56.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
25
26
27
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
30
31
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
32
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
33
34
35
36
37
38
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
39
from vllm.model_executor.layers.quantization import QuantizationMethods
40
from vllm.model_executor.layers.quantization.base_config import (
41
42
43
    QuantizationConfig,
    QuantizeMethodBase,
)
44
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
45
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
46
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
47
48
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
49
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
50
51
52
53
54
55
56
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
57
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
58
59
60
61
62
63
64
65
66
67
68
69
    W8A8BlockFp8LinearOp,
    check_aiter_fp8_linear_support,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace,
    validate_fp8_block_shape,
)
70
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
71
72
73
74
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
75
from vllm.model_executor.layers.quantization.utils.quant_utils import (
76
77
78
    GroupShape,
    is_layer_skipped,
)
79
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
80
81
82
83
84
85
86
87
88
89
90
91
92
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
93
from vllm.model_executor.utils import set_weight_attrs
94
from vllm.platforms import current_platform
95
from vllm.scalar_type import scalar_types
96
from vllm.utils.deep_gemm import (
97
    fp8_gemm_nt,
98
99
100
    get_col_major_tma_aligned_tensor,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
101
    should_use_deepgemm_for_fp8_linear,
102
)
103
from vllm.utils.flashinfer import has_flashinfer_moe
104
from vllm.utils.import_utils import has_deep_gemm
105

106
107
108
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

109
110
111
112
# Activation quantization schemes accepted by Fp8Config. With "static" an
# input scale parameter is registered and loaded from the checkpoint; with
# "dynamic" no input scale is registered (presumably computed at runtime).
ACTIVATION_SCHEMES = list(("static", "dynamic"))

# Module-level logger used for the backend-selection info/warning messages
# emitted by the FP8 linear and MoE code paths in this module.
logger = init_logger(__name__)

113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class Fp8MoeBackend(Enum):
    """Kernel backends that get_fp8_moe_backend() can select for FP8 MoE."""

    # Not returned by get_fp8_moe_backend(); presumably a "no backend chosen"
    # placeholder — confirm against other users of this enum.
    NONE = 0
    # FlashInfer TensorRT-LLM kernels; chosen on SM100 when FlashInfer MoE FP8
    # is enabled and FlashInfer reports the TENSORRT_LLM backend.
    FLASHINFER_TRTLLM = 1
    # FlashInfer CUTLASS kernels; the other FlashInfer option on SM100.
    FLASHINFER_CUTLASS = 2
    # DeepGEMM; only selected for block-quantized weights with VLLM_USE_DEEP_GEMM.
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM; SM100 with block-quantized weights.
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Weight-only Marlin path for GPUs below compute capability 8.9 (not ROCm).
    MARLIN = 5
    # Default fallback when no other backend applies.
    TRITON = 6


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """
    Pick the primary FP8 MoE kernel backend for the current platform.

    Candidates are tried in priority order and the first match wins:
    FlashInfer (SM100, opted in via VLLM_USE_FLASHINFER_MOE_FP8 and installed),
    Marlin (weight-only; GPUs below capability 8.9 or when forced, never on
    ROCm), DeepGEMM (block-quantized weights with VLLM_USE_DEEP_GEMM), CUTLASS
    block-scaled grouped GEMM (SM100 with block-quantized weights), Triton.
    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: whether the expert weights use block-wise FP8 scales.

    Returns:
        The selected Fp8MoeBackend member.
    """
    on_sm100 = current_platform.is_cuda() and current_platform.is_device_capability(
        100
    )

    # FlashInfer kernels take precedence on SM100 when enabled and available.
    if on_sm100 and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only Marlin for GPUs lacking native FP8 (never used on ROCm).
    wants_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if wants_marlin and not current_platform.is_rocm():
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM requires block-quantized weights and an explicit opt-in.
    if block_quant and envs.VLLM_USE_DEEP_GEMM:
        if not has_deep_gemm():
            logger.warning_once("DeepGEMM backend requested but not available.")
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS block-scaled grouped GEMM on SM100 for block-quantized weights.
    if on_sm100 and block_quant:
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


177
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds the options parsed from a checkpoint's quantization config and hands
    out the per-layer quantize methods (linear, fused MoE, KV cache).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: whether the checkpoint already stores
                FP8 weights; otherwise weights are quantized after loading.
            activation_scheme: one of ACTIVATION_SCHEMES ("static"/"dynamic").
            ignored_layers: layer prefixes that must stay unquantized.
            weight_block_size: optional [block_n, block_k] for block-wise
                weight scales.

        Raises:
            ValueError: for an unknown activation scheme, or an invalid
                block-wise configuration (non-FP8 checkpoint, block size not
                2-D, or a non-dynamic activation scheme).
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate HF-style names in ``ignored_layers`` to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict.

        ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Return True if the layer at ``prefix`` should stay unquantized.

        Checks ``ignored_layers``, resolving fused modules through
        ``packed_modules_mapping``.
        """
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU (IPEX) quantize method for ``layer``, or None."""
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            # Honour ignored_layers for MoE layers as well, matching the
            # non-XPU path in get_quant_method(); otherwise an ignored
            # (unquantized) MoE layer would be loaded as FP8.
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unsupported.

        Skipped (ignored) linear and MoE layers get their unquantized methods.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

329
330
331

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Weight scales are per-tensor (one per logical shard at load time) or
    block-wise when the config sets ``weight_block_size``. GPUs without native
    FP8 support use the Marlin weight-only kernel instead.

    Limitations:
    1. Non-block FP8 weights are requantized to a single per-tensor scale
       after loading, because torch._scaled_mm needs per-tensor weight scales.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Also disabled in batch-invariant mode; presumably because apply()'s
        # batch-invariant path reads layer.weight in its non-Marlin layout.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        The weight is [output_size_per_partition, input_size_per_partition].
        FP8-serialized checkpoints also get ``weight_scale`` (per-tensor) or
        ``weight_scale_inv`` (block-wise), plus ``input_scale`` (None unless
        the activation scheme is static). Non-serialized checkpoints only get
        the weight in ``params_dtype``; scales are produced after loading.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout apply() expects.

        - Block-quant: weights/scales go through
          process_fp8_weight_block_strategy; ``weight_scale_inv`` is replaced
          by ``weight_scale`` and the weight stays [N, K].
        - Non-serialized checkpoints: the weight is quantized with
          ops.scaled_fp8_quant(scale=None) and transposed to [K, N].
        - Serialized per-tensor: shard scales are requantized into one
          (non-Marlin only) and the weight is transposed to [K, N].
        Marlin layers are then repacked and lose ``input_scale``.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatch order: batch-invariant path, Marlin, block-quant W8A8, and
        finally the per-tensor/per-token Fp8LinearOp.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            return self._apply_batch_invariant(layer, x, bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

    def _apply_batch_invariant(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Batch-invariant forward: DeepGEMM FP8 GEMM when usable, else dequant.

        The DeepGEMM path returns a BF16 output. The dequant fallback computes
        in ``x.dtype`` so FP16 activations do not hit a dtype mismatch in
        F.linear.
        """
        # Call is_deep_gemm_supported() ahead of time for torch.compile
        # dynamo has trouble tracing through
        if self.block_quant and should_use_deepgemm_for_fp8_linear(
            torch.bfloat16, layer.weight, self.use_deep_gemm
        ):
            # use group quant consistent with block size across K
            assert self.act_q_group_shape is not None
            q_input, input_scale = QuantFP8(
                False,
                self.act_q_group_shape,
                column_major_scales=True,
            )(x)

            output_2d = torch.empty(
                (q_input.shape[0], layer.weight.shape[0]),
                dtype=torch.bfloat16,
                device=q_input.device,
            )
            fp8_gemm_nt(
                (q_input, input_scale),
                (layer.weight, layer.weight_scale),
                output_2d,
            )
            if bias is not None:
                output_2d = output_2d + bias
            return output_2d

        # Dequantize in BF16 (as before), then match the activation dtype so
        # F.linear does not fail for FP16 inputs (no-op for BF16 inputs).
        weight = self._dequantize_weight(layer).to(x.dtype)
        if not self.block_quant:
            # Per-tensor weights are stored transposed as [K, N]; F.linear
            # expects [N, K].
            weight = weight.t()
        return torch.nn.functional.linear(x, weight, bias)

    def _dequantize_weight(self, layer: torch.nn.Module) -> torch.Tensor:
        """Dequantize ``layer.weight`` to BF16 using ``layer.weight_scale``.

        Returns [N, K] for block-quantized layers and [K, N] for per-tensor
        layers (the layouts left by process_weights_after_loading).

        Raises:
            RuntimeError: if a block-quant scale is not 2-D.
        """
        weight = layer.weight.to(torch.bfloat16)
        weight_scale = layer.weight_scale.to(torch.bfloat16)

        if self.block_quant:
            # Block-wise quantization: weight is NOT transposed, shape [N, K].
            assert self.weight_block_size is not None
            block_n, block_k = self.weight_block_size  # Note: order is [N, K]
            n, k = weight.shape

            # determine expected number of blocks along N and K
            num_blocks_n = (n + block_n - 1) // block_n
            num_blocks_k = (k + block_k - 1) // block_k

            # scale layout may be [num_blocks_n, num_blocks_k]
            # or [num_blocks_k, num_blocks_n] depending on backend
            if weight_scale.dim() != 2:
                raise RuntimeError(
                    f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
                )

            scale_rows, scale_cols = weight_scale.shape
            if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
                if num_blocks_n == num_blocks_k:
                    # Ambiguous square case: keep the layout. Warn only once,
                    # since this runs on every forward pass.
                    logger.warning_once(
                        "Batch-invariant FP8: square block-scale %dx%d; "
                        "skipping transpose to avoid misorientation.",
                        scale_rows,
                        scale_cols,
                    )
                else:
                    # clear KN -> transpose to NK
                    weight_scale = weight_scale.t()

            # Expand each block scale over its block, then trim any padding.
            scale_expanded = weight_scale.repeat_interleave(
                block_n, dim=0
            ).repeat_interleave(block_k, dim=1)
            scale_expanded = scale_expanded[:n, :k]
            return weight * scale_expanded

        # Per-tensor quantization: weight IS transposed to [K, N].
        if weight_scale.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return weight * weight_scale

        n_out = weight.shape[1]
        if weight_scale.numel() == n_out:
            # One scale per output channel ([N] or [N, 1]): scale the N
            # columns. Checked first so square weights are not scaled
            # along K by mistake.
            return weight * weight_scale.reshape(1, n_out)

        if weight_scale.dim() == 1 and weight_scale.shape[0] == weight.shape[0]:
            # Legacy per-row (K) scaling, kept for backward compatibility.
            return weight * weight_scale.unsqueeze(1)

        # Fallback: rely on standard broadcasting.
        return weight * weight_scale
685
686


687
688
689
690
691
692
693
694
695
696
697
698
699
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale (per-tensor or block-wise weight scales).

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling: the weights are quantized to FP8 and the weight scaling
    factors are computed in ``process_weights_after_loading``.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None

        self.fused_experts: mk.FusedMoEModularKernel | None = None  # type: ignore

        self.fp8_backend = get_fp8_moe_backend(self.block_quant)

        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

        # Refreshed in create_weights() / process_weights_after_loading();
        # initialised here so the attribute always exists on the instance.
        self.rocm_aiter_moe_enabled = False

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weight, weight-scale and input-scale params.

        Allocates on ``layer``:
          * ``w13_weight``: (num_experts, 2 * intermediate, hidden)
          * ``w2_weight``: (num_experts, hidden, intermediate)
          * per-tensor scales ``w13_weight_scale`` (num_experts, 2) and
            ``w2_weight_scale`` (num_experts,), or, for block quantization,
            block-grid scales ``w13_weight_scale_inv`` / ``w2_weight_scale_inv``
          * ``w13_input_scale`` / ``w2_input_scale`` (num_experts,) for the
            static activation scheme, otherwise both are set to ``None``.

        Args:
            layer: Module the parameters are registered on.
            num_experts: Number of experts to allocate.
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Per-rank expert intermediate size.
            params_dtype: Checkpoint dtype; replaced by ``torch.float8_e4m3fn``
                when the checkpoint is fp8-serialized.
            **extra_weight_attrs: Attributes forwarded to ``set_weight_attrs``
                (presumably including the weight loader - confirm at caller).

        Raises:
            ValueError: If the intermediate size is not divisible by the
                quantization block size, or if the static activation scheme
                is requested for a non-fp8-serialized checkpoint.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding partial tiles up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the layout the backend expects.

        Three mutually exclusive cases:
          * block quantization: optional fnuz normalisation or FlashInfer
            w13->w31 swap, AITER shuffling, DeepGEMM scale alignment;
          * unquantized (fp16/bf16) checkpoint: quantize each expert to FP8
            in place with a single per-expert w13 scale;
          * fp8 per-tensor checkpoint: collapse static input scales and the
            two w1/w3 scales to their max, requantize w13, and prepare the
            FlashInfer layout if that backend is selected.
        Afterwards, Marlin repacking and DeepGEMM UE8M0 requantization are
        applied when enabled.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled,
            shuffle_weights,
        )

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                # Unwrap w2 via .data as well, matching w13 and the other
                # branches, so every Parameter below wraps a plain tensor.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w13_weight_scale_inv
                    )
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w2_weight_scale_inv
                    )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # shard 0 / shard 1 are the w1 and w3 halves of w13.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # The helper's return value is unused, so it is expected to
                    # update its arguments in place (w13 is copied back below).
                    # w2 must therefore be the tensor registered on `layer`: a
                    # local `w2_weight` is only bound on the fnuz path above and
                    # would raise UnboundLocalError everywhere else.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )

    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend.

        Returns ``None`` for backends that do not use the modular kernel
        (ROCm AITER, Marlin, FlashInfer TRT-LLM), a FlashInfer CUTLASS
        prepare/finalize for that backend, and the base-class default
        otherwise.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching ``prepare_finalize``.

        Batched activation format -> ``BatchedTritonOrDeepGemmExperts``;
        FlashInfer CUTLASS backend -> ``select_cutlass_fp8_gemm_impl``;
        otherwise ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fp8 w8a8 quant config from the layer's scales.

        Uses the block-wise ``*_weight_scale_inv`` tensors when block
        quantization is active, otherwise the per-tensor ``*_weight_scale``.
        Returns ``None`` for Marlin, which does not use this config.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and run the FP8 fused-experts computation.

        The FlashInfer TRT-LLM backend performs routing itself and returns
        directly. All other backends first call ``FusedMoE.select_experts``
        and then dispatch to ROCm AITER, Marlin, the modular kernel
        (``self.fused_experts``), FlashInfer CUTLASS, or the Triton
        ``fused_experts`` fallback, in that order.

        Returns:
            The expert output, or ``(output, zero_expert_result)`` when the
            layer defines zero experts (``zero_expert_num`` != 0 and a
            ``zero_expert_type``).
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if (
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
            and self.fused_experts is None
        ):
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )
            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                assert (
                    renormalize and use_grouped_topk and custom_routing_function is None
                )
                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert not renormalize and custom_routing_function is not None
                # Return immediately: falling through would discard this result
                # and recompute via the generic path below, which would run on
                # weights already swapped/rotated for the FlashInfer layout.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )

        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        #
        # Note: the order of checks is important since self.fused_experts
        # can override fused_experts or cutlass but not rocm or marlin.
        #
        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            assert self.fused_experts is None
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            assert self.fused_experts is None
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif self.fused_experts:
            result = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert not renormalize and custom_routing_function is not None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )

            result = flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )
        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1397
1398


1399
1400
1401
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1402
1403
1404
    """

    def __init__(self, quant_config: Fp8Config):
1405
        super().__init__(quant_config)