"vscode:/vscode.git/clone" did not exist on "4b3e2d5edd5354093d03432086752ba0aa7c9f03"
fp8.py 54.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from functools import partial
7
from typing import TYPE_CHECKING, Any, Optional
8
9
10
11
12

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

13
import vllm.envs as envs
14
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
15
from vllm import _custom_ops as ops
16
from vllm._aiter_ops import rocm_aiter_ops
17
from vllm.attention.layer import Attention
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.batch_invariant import (
21
    vllm_is_batch_invariant,
22
)
bnellnm's avatar
bnellnm committed
23
from vllm.model_executor.layers.fused_moe import (
24
25
26
27
28
29
30
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
31
from vllm.model_executor.layers.fused_moe.config import (
32
    FusedMoEParallelConfig,
33
    FusedMoEQuantConfig,
34
    RoutingMethodType,
35
36
    fp8_w8a8_moe_quant_config,
)
37
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
38
39
40
41
42
43
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
from vllm.model_executor.layers.quantization.base_config import (
46
47
48
    QuantizationConfig,
    QuantizeMethodBase,
)
49
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
50
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
51
52
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
53
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
54
55
56
57
58
59
60
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
61
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
62
63
64
65
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
66
    deepgemm_post_process_fp8_weight_block,
67
68
69
70
71
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
72
73
74
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
75
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
76
77
78
79
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
80
from vllm.model_executor.layers.quantization.utils.quant_utils import (
81
82
83
    GroupShape,
    is_layer_skipped,
)
84
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
85
86
87
88
89
90
91
92
93
94
95
96
97
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
98
from vllm.model_executor.utils import set_weight_attrs
99
from vllm.platforms import current_platform
100
from vllm.scalar_type import scalar_types
101
102
103
104
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
105
from vllm.utils.flashinfer import has_flashinfer_moe
106
from vllm.utils.import_utils import has_deep_gemm
107

108
109
110
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

111
112
113
114
# Activation-quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger used for backend-selection messages in this file.
logger = init_logger(__name__)

115

116
117
118
119
120
121
122
123
124
125
class Fp8MoeBackend(Enum):
    """Identifiers for the kernel backends an FP8 fused-MoE layer can use.

    ``get_fp8_moe_backend`` returns one of these members.
    """

    # Placeholder value; not returned by get_fp8_moe_backend.
    NONE = 0
    # FlashInfer backends (TensorRT-LLM and CUTLASS variants).
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    # DeepGEMM; only selected for block-quantized weights.
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM; only selected for block-quantized
    # weights.
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Marlin kernel path (weight-only FP8 for GPUs without native FP8).
    MARLIN = 5
    # Default when no other backend is selected.
    TRITON = 6


126
127
128
def get_fp8_moe_backend(
    block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
) -> Fp8MoeBackend:
    """Select the primary FP8 MoE backend.

    Candidates are tried in a fixed priority order and the first match wins:

    1. FlashInfer (TRTLLM or CUTLASS) on CUDA when
       ``current_platform.is_device_capability`` is true for 100 or 90,
       ``VLLM_USE_FLASHINFER_MOE_FP8`` is set and FlashInfer MoE is available.
    2. Marlin when ``current_platform.has_device_capability(89)`` is false or
       ``VLLM_TEST_FORCE_FP8_MARLIN`` is set; never on ROCm.
    3. DeepGEMM for block-quantized weights when it is enabled, installed
       and supported.
    4. CUTLASS block-scaled grouped GEMM for block-quantized weights on CUDA
       with device capability 100.
    5. Triton otherwise.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the layer's weights are block-quantized.
        moe_parallel_config: Parallel configuration of the MoE layer; only
            its ``tp_size`` is read, to decide the DeepGEMM default.

    Returns:
        The selected ``Fp8MoeBackend`` member.

    Raises:
        ValueError: If the FlashInfer CUTLASS backend would be used with
            block-quantized weights on a capability-100 device.
    """
    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM

        # Any other FlashInfer backend is treated as the CUTLASS one.
        if block_quant and current_platform.is_device_capability(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    # ROCm overrides the decision above, including the forced-Marlin test flag.
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # Determine if we should use DeepGEMM with block-quantized weights:
    # - If explicitly set by user, respect their choice
    # - If not explicitly set (default), disable when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM
        # DeepGEMM installed but unsupported: fall through to the next
        # candidate without logging.

    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
    if (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and block_quant
    ):
        logger.info_once(
            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE", scope="local"
        )
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


206
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds the checkpoint format flag, the activation quantization scheme,
    the list of layers excluded from quantization and the optional
    block-wise weight quantization shape, and maps layers to the quantization
    method that should handle them.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint already
                stores FP8 weights.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer names that should stay unquantized;
                ``None`` is stored as an empty list.
            weight_block_size: Two-element block shape for block-wise weight
                quantization, or ``None`` to disable block quantization.

        Raises:
            ValueError: If ``activation_scheme`` is unknown, or if
                ``weight_block_size`` is given together with a
                non-serialized checkpoint, a length other than 2, or a
                non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method, ``"fp8"``."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes supported: bfloat16 and half."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config filenames to search for; none are needed."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` through ``hf_to_vllm_mapper``.

        Args:
            hf_to_vllm_mapper: Mapper whose ``apply_list`` is applied to the
                stored ignored-layer names.
        """
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dictionary.

        The checkpoint is considered FP8-serialized when the ``quant_method``
        value contains ``"fp8"``. ``ignored_layers`` falls back to the
        ``modules_to_not_convert`` key when it is missing or empty.

        Args:
            config: Quantization config; ``quant_method`` and
                ``activation_scheme`` are looked up with ``get_from_keys``,
                the remaining keys with ``get_from_keys_or`` and a ``None``
                default.

        Returns:
            The constructed ``Fp8Config``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer`` on XPU platforms.

        Args:
            layer: The layer to be quantized.
            prefix: Name of the layer, used for the ignored-layer check on
                linear layers.

        Returns:
            ``UnquantizedLinearMethod`` for skipped linear layers,
            ``XPUFp8LinearMethod`` for other linear layers,
            ``XPUFp8MoEMethod`` for ``FusedMoE`` layers, ``Fp8KVCacheMethod``
            for ``Attention`` layers, and ``None`` for anything else.
        """
        # Imported lazily so the IPEX-backed classes are only loaded when the
        # XPU path is actually taken.
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh config carrying the same settings.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method that should handle ``layer``.

        On XPU the decision is delegated to ``get_xpu_quant_method``.
        Otherwise linear and fused-MoE layers listed in ``ignored_layers``
        get their unquantized method, other linear / fused-MoE layers get the
        FP8 method (with ``marlin_input_dtype`` resolved from ``prefix``),
        and attention layers get the FP8 KV-cache method.

        Args:
            layer: The layer to be quantized.
            prefix: Name of the layer, used for the ignored-layer check and
                the Marlin input dtype lookup.

        Returns:
            The quantization method, or ``None`` when the layer type is not
            handled here.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_quant_method = Fp8MoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        The q-projection output scale and the attention probability output
        scale are remapped in the same way.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
            when the name matches none of the known patterns
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

359
360
361

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    When the config carries a ``weight_block_size``, the block-wise path
    (``W8A8BlockFp8LinearOp``) is used instead of ``Fp8LinearOp``.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Resolve which kernel path this method will use.

        Args:
            quant_config: FP8 config supplying ``weight_block_size`` and
                ``activation_scheme``.
        """
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        # marlin_input_dtype is filled in by Fp8Config.get_quant_method after
        # construction.
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode also disables marlin.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # Block quantization: activation groups span one row and
            # weight_block_size[0] columns.
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        # Exactly one of the two linear ops is created, depending on whether
        # block quantization is active.
        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        Args:
            layer: Module that receives the parameters and bookkeeping
                attributes (``logical_widths``, partition sizes,
                ``orig_dtype``, ``weight_block_size``).
            input_size_per_partition: Input dimension of this partition.
            output_partition_sizes: Output widths of the logical shards; their
                sum is the partition's output dimension.
            input_size: Full input size; only used for block-shape validation.
            output_size: Full output size; only used for block-shape
                validation.
            params_dtype: Dtype of the weight when the checkpoint is not
                FP8-serialized; also stored as ``layer.orig_dtype``.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the layout the kernels expect.

        Three cases are handled:

        * block-quantized checkpoint: weight and ``weight_scale_inv`` go
          through ``process_fp8_weight_block_strategy`` and
          ``weight_scale_inv`` is removed from the layer;
        * non-FP8 checkpoint: the weight is quantized with
          ``ops.scaled_fp8_quant`` and transposed;
        * per-tensor FP8 checkpoint: unless Marlin is used, weight and scales
          go through ``process_fp8_weight_tensor_strategy``; the weight is
          then transposed.

        Afterwards ``layer.weight``, ``layer.weight_scale`` and
        ``layer.input_scale`` are replaced, and the Marlin or block-wise
        post-processing is applied when relevant.

        Args:
            layer: Module whose parameters were registered by
                ``create_weights`` and have been loaded.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            # The block-quantized weight is not transposed below, hence
            # size_k_first is cleared for the Marlin preparation.
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    # Collapse the input scales to their maximum.
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatch order: batch-invariant mode first (block-wise op, or a BF16
        dequantized GEMM for non-block weights), then Marlin, then the
        block-wise op, and finally ``Fp8LinearOp``.

        Args:
            layer: Module holding ``weight``, ``weight_scale`` and, except on
                the Marlin path, ``input_scale``.
            x: Input activations.
            bias: Optional bias forwarded to the selected kernel.

        Returns:
            The output tensor of the selected linear kernel.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM
                weight_fp8 = layer.weight.to(torch.bfloat16)
                weight_scale = layer.weight_scale.to(torch.bfloat16)
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                else:
                    # Multiple scales (fused modules like QKV)
                    # Try to infer correct broadcasting
                    # weight is [K, N], scale could be [num_logical_weights]
                    # Need to figure out how to broadcast - for now just try
                    # direct multiplication
                    if (
                        weight_scale.dim() == 1
                        and weight_scale.shape[0] == weight_fp8.shape[0]
                    ):
                        # Per-row scaling
                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                    else:
                        # Fallback
                        weight_bf16 = weight_fp8 * weight_scale
                # The stored weight was transposed after loading; transpose it
                # back for torch.nn.functional.linear.
                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
646
647


648
649
650
651
652
653
654
655
656
657
658
659
660
class Fp8MoEMethod(FusedMoEMethodBase):
    """Fused-MoE quantization method for FP8 weights.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading FP16/BF16 model checkpoints, which are quantized
    to FP8 once the weights are loaded, with dynamic activation scaling.
    The weight scaling factor is initialized after the model weights are
    loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            and ``moe_parallel_config`` are read during construction.
    """

661
662
663
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Select the FP8 MoE backend and cache the flags derived from it.

        Args:
            quant_config: FP8 quantization config. A non-``None``
                ``weight_block_size`` turns on block quantization.
            layer: MoE layer providing ``moe_config`` (forwarded to the base
                class) and ``moe_parallel_config`` (used for backend
                selection).
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config

        block_size = quant_config.weight_block_size
        self.weight_block_size = block_size
        self.block_quant: bool = block_size is not None

        backend = get_fp8_moe_backend(self.block_quant, layer.moe_parallel_config)
        self.fp8_backend = backend

        self.marlin_input_dtype = None
        self.use_marlin = backend == Fp8MoeBackend.MARLIN

        # FlashInfer flavour in use; stays None for every other backend.
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant:
                assert block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Pre-bind the arguments that never change between calls.
            self.flashinfer_moe_fn = partial(
                flashinfer_cutlass_moe_fp8,
                moe=self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )

        self.allow_deep_gemm = backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )
693

694
695
696
697
698
699
700
701
702
    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the expert weights and scales on ``layer``.

        Registers ``w13_weight`` / ``w2_weight``, their weight scales
        (``*_weight_scale`` for per-tensor, ``*_weight_scale_inv`` for block
        quantization) and, for the static activation scheme, the input
        scales. For any other activation scheme the input-scale attributes
        are set to ``None``.

        Args:
            layer: Module receiving the parameters and bookkeeping attributes.
            num_experts: Size of the leading dimension of every parameter.
            hidden_size: Hidden dimension of the expert weights.
            intermediate_size_per_partition: Intermediate dimension of the
                expert weights on this partition.
            params_dtype: Weight dtype; replaced by ``torch.float8_e4m3fn``
                when the checkpoint is FP8-serialized.
            **extra_weight_attrs: Attributes attached to the parameters via
                ``set_weight_attrs``; a ``quant_method`` entry is added
                before the scales are tagged.

        Raises:
            ValueError: If the block sizes do not divide the intermediate
                size, or if the static activation scheme is requested for a
                checkpoint that is not FP8-serialized.
        """
        quant_cfg = self.quant_config
        fp8_serialized = quant_cfg.is_checkpoint_fp8_serialized

        def _frozen(data: torch.Tensor) -> torch.nn.Parameter:
            return torch.nn.Parameter(data, requires_grad=False)

        def _ones(*shape: int) -> torch.nn.Parameter:
            return _frozen(torch.ones(*shape, dtype=torch.float32))

        # Bookkeeping read back later (e.g. in process_weights_after_loading).
        for attr_name, attr_value in (
            ("intermediate_size_per_partition", intermediate_size_per_partition),
            ("hidden_size", hidden_size),
            ("num_experts", num_experts),
            ("orig_dtype", params_dtype),
            ("weight_block_size", None),
        ):
            setattr(layer, attr_name, attr_value)

        if fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = self.weight_block_size[0]
            block_k = self.weight_block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            # Required by row parallel
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
        w13_weight = _frozen(torch.empty(*w13_shape, dtype=params_dtype))
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)
        w2_weight = _frozen(torch.empty(*w2_shape, dtype=params_dtype))
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:

            def _num_blocks(size: int, block: int) -> int:
                # Ceiling division: number of blocks needed to cover `size`.
                return (size + block - 1) // block

            w13_weight_scale = _ones(
                num_experts,
                2 * _num_blocks(intermediate_size_per_partition, block_n),
                _num_blocks(hidden_size, block_k),
            )
            w2_weight_scale = _ones(
                num_experts,
                _num_blocks(hidden_size, block_n),
                _num_blocks(intermediate_size_per_partition, block_k),
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert quant_cfg.activation_scheme == "dynamic"
            scale_kind = FusedMoeWeightScaleSupported.BLOCK
        else:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = _ones(num_experts, 2)
            w2_weight_scale = _ones(num_experts)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            scale_kind = FusedMoeWeightScaleSupported.TENSOR

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update({"quant_method": scale_kind.value})

        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (the weights are quantized in
        # process_weights_after_loading()).
        if fp8_serialized:
            for scale_param in (w13_weight_scale, w2_weight_scale):
                set_weight_attrs(scale_param, extra_weight_attrs)

        # INPUT_SCALES
        if quant_cfg.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
        else:
            if not fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )
            for scale_name in ("w13_input_scale", "w2_input_scale"):
                input_scale = _ones(num_experts)
                layer.register_parameter(scale_name, input_scale)
                set_weight_attrs(input_scale, extra_weight_attrs)

        self.rocm_aiter_moe_enabled = False

837
    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights and scales into the layout the kernels use.

        Exactly one of three cases runs:

        * Block-quantized checkpoint: optionally normalize to e4m3fnuz or
          swap the w1/w3 halves for FlashInfer, re-wrap everything as plain
          ``Parameter`` objects, then apply the ROCm AITER shuffle and the
          DeepGEMM post-processing when those are enabled.
        * Checkpoint that is not FP8-serialized: quantize ``w13`` and ``w2``
          of every local expert to the platform FP8 dtype, producing one
          weight scale per expert.
        * FP8-serialized per-tensor checkpoint: reduce the input scales and
          the two ``w13`` weight scales to their maximum and requantize
          ``w13`` with that single scale.

        Afterwards, when Marlin is selected, the layer is prepared for Marlin
        and the input-scale attributes are deleted.

        Args:
            layer: The MoE layer whose parameters are rewritten in place.

        Raises:
            ValueError: If the activation scheme is static but an input scale
                is ``None``.
        """
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # The returned input scales are not used on this path.
                w13_weight, w13_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight,
                    layer.w13_weight_scale_inv,
                    layer.w13_input_scale,
                )
                w2_weight, w2_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # The two shards are the two halves of the merged w13 weight,
                # each still carrying its own scale at this point.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Use the layer's own w2 parameter: a local `w2_weight`
                    # is only bound on the fnuz path above, so it cannot be
                    # relied on here.
                    # NOTE(review): w2 is not written back after this call,
                    # so the helper presumably updates it in place - confirm.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale
1052

1053
1054
1055
1056
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Build the prepare/finalize object for the active backend, if any.

        Args:
            routing_tables: Forwarded unchanged to the base-class
                implementation when no backend-specific handling applies.

        Returns:
            ``None`` for ROCm AITER, Marlin and FlashInfer TRT-LLM; the
            FlashInfer CUTLASS prepare/finalize object for that backend;
            otherwise whatever the base class returns.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            uses_block_scale = self.block_quant
            if uses_block_scale:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            cutlass_pf = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=uses_block_scale,
            )
            logger.debug_once("%s", cutlass_pf.__class__.__name__)
            return cutlass_pf

        return super().maybe_make_prepare_finalize(routing_tables)
1078

bnellnm's avatar
bnellnm committed
1079
1080
1081
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation matching ``prepare_finalize``.

        Args:
            prepare_finalize: Its ``activation_format`` decides between the
                batched and the standard implementations.
            layer: Not read by this method.

        Returns:
            A batched DeepGEMM/Triton experts object for the batched
            activation format, the FlashInfer CUTLASS experts for that
            backend, and ``TritonOrDeepGemmExperts`` otherwise.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not (self.use_marlin or self.rocm_aiter_moe_enabled), (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )
        assert self.moe_quant_config is not None

        method_name = self.__class__.__name__
        is_batched = (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        )

        if is_batched:
            tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert tokens_per_rank is not None

            if self.allow_deep_gemm:
                batched_cls = BatchedDeepGemmExperts
            else:
                batched_cls = BatchedTritonExperts
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                batched_cls.__name__,
                method_name,
                tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return batched_cls(
                max_num_tokens=tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            cutlass_experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", cutlass_experts.__class__.__name__)
            return cutlass_experts

        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            method_name,
            self.weight_block_size,
            False,
        )
        return TritonOrDeepGemmExperts(
            quant_config=self.moe_quant_config,
            allow_deep_gemm=self.allow_deep_gemm,
        )

1141
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config from the scales stored on ``layer``.

        Args:
            layer: MoE layer holding the weight and input scales.

        Returns:
            ``None`` when Marlin is in use; otherwise the config built from
            the block (``*_scale_inv``) or per-tensor weight scales, the
            input scales and the weight block size.
        """
        if self.use_marlin:
            return None

        # Block quantization stores its weight scales under `*_scale_inv`.
        if self.block_quant:
            w13_scale = layer.w13_weight_scale_inv
            w2_scale = layer.w2_weight_scale_inv
        else:
            w13_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

        return fp8_w8a8_moe_quant_config(
            w1_scale=w13_scale,
            w2_scale=w2_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

1161
1162
1163
1164
1165
1166
1167
1168
    @property
    def supports_eplb(self) -> bool:
        """Report that this method supports EPLB; always ``True``."""
        return True

    @property
    def allow_inplace(self) -> bool:
        """Report that in-place execution is allowed; always ``True``."""
        # NOTE(review): how callers use this flag is defined outside this
        # class - confirm against the base class.
        return True

1169
1170
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass with the backend chosen at init.

        The FlashInfer TRT-LLM backend is handled first and returns directly
        without calling ``layer.select_experts``. Every other backend first
        calls ``layer.select_experts`` and then dispatches to ROCm AITER,
        Marlin, FlashInfer CUTLASS or the default ``fused_experts`` path.

        Args:
            layer: The MoE layer holding the weights and scales.
            x: Input hidden states.
            router_logits: Router logits for ``x``.
            top_k: Forwarded to the FlashInfer TRT-LLM kernels.
            renormalize: Must be ``False`` on the FlashInfer per-tensor paths.
            use_grouped_topk: Not read by this method.
            topk_group: Forwarded to the FlashInfer TRT-LLM kernels.
            num_expert_group: Forwarded to the FlashInfer TRT-LLM kernels.
            global_num_experts: Forwarded to the selected kernel.
            expert_map: Forwarded to the selected kernel.
            custom_routing_function: Must be set on the FlashInfer per-tensor
                paths.
            scoring_func: Must be ``"sigmoid"`` on the FlashInfer CUTLASS
                per-tensor path.
            routed_scaling_factor: Forwarded to the FlashInfer TRT-LLM
                block-scale kernel.
            e_score_correction_bias: Routing bias for the FlashInfer TRT-LLM
                kernels.
            apply_router_weight_on_input: Forwarded to the selected kernel.
            activation: Activation name; the FlashInfer and Marlin paths
                accept only ``"silu"``.
            enable_eplb: When set, the three EPLB tensors below must be
                provided.
            expert_load_view: Required when ``enable_eplb`` is set.
            logical_to_physical_map: Required when ``enable_eplb`` is set.
            logical_replica_count: Required when ``enable_eplb`` is set.

        Returns:
            The kernel output, or ``(output, zero_expert_result)`` when the
            layer has zero experts configured.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=routed_scaling_factor,
                )

            assert not renormalize and custom_routing_function is not None
            # Return here, like the block-scale branch above. Falling through
            # would reach the default `fused_experts` branch below and
            # overwrite this result.
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        select_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                input_dtype=self.marlin_input_dtype,
                workspace=layer.workspace,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            if not self.block_quant:
                assert not renormalize and custom_routing_function is not None
                assert scoring_func == "sigmoid", (
                    f"Expected 'sigmoid' scoring func but got {scoring_func}"
                )
            # Delegate to CUTLASS FlashInfer path; function already bound with
            # use_deepseek_fp8_block_scale for block-quant when applicable
            result = self.flashinfer_moe_fn(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )

        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
            # A tuple result would mean a shared-experts output is already
            # paired with the routed output; that cannot be combined with the
            # zero-expert result here.
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1340
1341


1342
1343
1344
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1345
1346
1347
    """

    def __init__(self, quant_config: Fp8Config):
1348
        super().__init__(quant_config)