fp8.py 56.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10
11

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
25
26
27
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
30
31
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
32
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
33
34
35
36
37
38
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
39
from vllm.model_executor.layers.quantization import QuantizationMethods
40
from vllm.model_executor.layers.quantization.base_config import (
41
42
43
    QuantizationConfig,
    QuantizeMethodBase,
)
44
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
45
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
46
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
47
48
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
49
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
50
51
52
53
54
55
56
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
57
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
58
59
60
61
62
63
64
65
66
67
68
69
    W8A8BlockFp8LinearOp,
    check_aiter_fp8_linear_support,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    requant_weight_ue8m0_inplace,
    validate_fp8_block_shape,
)
70
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
71
72
73
74
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
75
from vllm.model_executor.layers.quantization.utils.quant_utils import (
76
77
78
    GroupShape,
    is_layer_skipped,
)
79
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
80
81
82
83
84
85
86
87
88
89
90
91
92
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
93
from vllm.model_executor.utils import set_weight_attrs
94
from vllm.platforms import current_platform
95
from vllm.scalar_type import scalar_types
96
from vllm.utils import has_deep_gemm
97
from vllm.utils.deep_gemm import (
98
    fp8_gemm_nt,
99
100
101
    get_col_major_tma_aligned_tensor,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
102
    should_use_deepgemm_for_fp8_linear,
103
)
104
from vllm.utils.flashinfer import has_flashinfer_moe
105

106
107
108
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

109
110
111
112
logger = init_logger(__name__)

# Activation quantization schemes accepted by ``Fp8Config``. With "static",
# linear layers register an ``input_scale`` parameter loaded from the
# checkpoint; with "dynamic", no input scale parameter is registered.
ACTIVATION_SCHEMES = ["static", "dynamic"]

113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class Fp8MoeBackend(Enum):
    """Kernel families that can execute an FP8 fused-MoE layer.

    ``NONE`` is a placeholder member; ``get_fp8_moe_backend`` never returns it.
    """

    # Placeholder / unset.
    NONE = 0
    # FlashInfer kernels (SM100, opt-in via VLLM_USE_FLASHINFER_MOE_FP8).
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    # DeepGEMM grouped GEMM for block-quantized weights.
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM (SM100, block-quantized weights).
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Weight-only FP8 through Marlin.
    MARLIN = 5
    # Generic Triton fallback.
    TRITON = 6


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """Choose the primary kernel backend for an FP8 fused-MoE layer.

    The first matching rule wins:
      1. FlashInfer (TRT-LLM or CUTLASS flavour) on CUDA SM100 when
         ``VLLM_USE_FLASHINFER_MOE_FP8`` is set and FlashInfer MoE is present.
      2. Marlin weight-only when the device lacks capability 8.9 or
         ``VLLM_TEST_FORCE_FP8_MARLIN`` is set; never on ROCm.
      3. DeepGEMM for block-quantized weights when ``VLLM_USE_DEEP_GEMM`` is
         set and DeepGEMM is installed and supported.
      4. CUTLASS block-scaled grouped GEMM for block-quantized weights on
         CUDA SM100.
      5. Triton otherwise.

    Args:
        block_quant: whether the expert weights are block-quantized.

    Returns:
        The selected ``Fp8MoeBackend``. Shape-specific fallbacks may still
        occur at runtime.
    """
    cuda_sm100 = current_platform.is_cuda() and current_platform.is_device_capability(
        100
    )

    # FlashInfer backends, when available and enabled on supported GPUs.
    flashinfer_ok = (
        cuda_sm100 and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()
    )
    if flashinfer_ok:
        if get_flashinfer_moe_backend() != FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
            return Fp8MoeBackend.FLASHINFER_CUTLASS
        logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
        return Fp8MoeBackend.FLASHINFER_TRTLLM

    # Weight-only path for older GPUs without native FP8 (disabled on ROCm).
    wants_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if wants_marlin and not current_platform.is_rocm():
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM for block-quantized weights, if requested and usable.
    if envs.VLLM_USE_DEEP_GEMM and block_quant:
        if not has_deep_gemm():
            logger.warning_once("DeepGEMM backend requested but not available.")
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS block-scaled grouped GEMM on SM100 for block-quantized weights.
    if cuda_sm100 and block_quant:
        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


177
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Built from a checkpoint's quantization config (``from_config``) and used
    to choose the quantize method for each layer (``get_quant_method``).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """Validate and store the FP8 settings.

        Args:
            is_checkpoint_fp8_serialized: whether the checkpoint already stores
                FP8 weights; otherwise weights are quantized after loading.
            activation_scheme: one of ``ACTIVATION_SCHEMES``.
            ignored_layers: layer prefixes that stay unquantized.
            weight_block_size: ``[block_n, block_k]`` for block-wise weight
                quantization, or ``None``.

        Raises:
            ValueError: for an unsupported activation scheme, or a block-wise
                configuration that is not fp8-serialized, not 2-D, or not
                using dynamic activation scaling.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the quantization method name, ``"fp8"``."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes this config accepts: bf16 and fp16."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability (8.0)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are needed."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization-config dict.

        ``ignored_layers`` falls back to ``modules_to_not_convert`` when it is
        absent or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """XPU counterpart of ``get_quant_method`` using the IPEX FP8 methods.

        Layers matched by ``ignored_layers`` are left unquantized for both
        linear and fused-MoE layers, mirroring ``get_quant_method``.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            # Previously ignored_layers was not consulted for MoE layers on
            # XPU, so "skipped" experts were still FP8-quantized.
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer`` (or ``None`` if unsupported).

        Dispatches to the XPU variant on XPU platforms. Linear and fused-MoE
        layers matched by ``ignored_layers`` get unquantized methods.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

329
330
331

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling: the weights are quantized to FP8 and the weight
    scaling factor is computed after the model weights are loaded.

    Limitations:
    1. Without block quantization, the per-shard scales of a fused module are
       requantized into a single per-tensor scale for torch._scaled_mm
       (unless Marlin is used).
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode uses its own forward path (see apply).
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # Activations are quantized per token in groups along K, so the
            # group width must equal the weight's K block size. The block
            # size is ordered [block_n, block_k].
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[1])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight, weight-scale and input-scale parameters.

        FP8-serialized checkpoints get an FP8 weight plus either a per-shard
        ``weight_scale`` or a block ``weight_scale_inv``. Non-serialized
        checkpoints get a weight in ``params_dtype`` that is quantized in
        ``process_weights_after_loading``.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the selected kernel expects.

        Block-quantized weights stay ``[N, K]`` and ``weight_scale_inv`` is
        renamed to ``weight_scale``. All other weights end up transposed to
        ``[K, N]``. Marlin layers are repacked and lose ``input_scale``.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the FP8 weights of ``layer``.

        Dispatch order: batch-invariant path (when enabled), Marlin,
        block-wise W8A8, then per-tensor W8A8.
        """
        # In batch-invariant mode prefer the DeepGEMM FP8 path; otherwise fall
        # back to a dequantized F.linear.
        if vllm_is_batch_invariant():
            return self._apply_batch_invariant(layer, x, bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

    def _apply_batch_invariant(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Forward pass used when batch-invariant mode is enabled.

        Uses the DeepGEMM FP8 GEMM for block-quantized weights with bf16
        activations. Otherwise dequantizes the weight and runs ``F.linear``
        in the activation dtype, so fp16 models also work.
        """
        # fp8_gemm_nt writes into a bf16 buffer, so only take this path when
        # the activations (and hence the expected output dtype) are bf16.
        if (
            self.block_quant
            and x.dtype == torch.bfloat16
            and should_use_deepgemm_for_fp8_linear(torch.bfloat16, layer.weight, None)
        ):
            # use group quant consistent with block size across K
            assert self.act_q_group_shape is not None
            # The GEMM works on 2-D [M, K] inputs; flatten any leading dims.
            x_2d = x.reshape(-1, x.shape[-1])
            q_input, input_scale = QuantFP8(
                False,
                self.act_q_group_shape,
                column_major_scales=True,
            )(x_2d)

            output_2d = torch.empty(
                (q_input.shape[0], layer.weight.shape[0]),
                dtype=torch.bfloat16,
                device=q_input.device,
            )
            fp8_gemm_nt(
                (q_input, input_scale),
                (layer.weight, layer.weight_scale),
                output_2d,
            )
            if bias is not None:
                output_2d = output_2d + bias
            return output_2d.view(*x.shape[:-1], layer.weight.shape[0])

        weight_nk = self._dequantize_weight_nk(layer).to(x.dtype)
        return torch.nn.functional.linear(x, weight_nk, bias)

    def _dequantize_weight_nk(self, layer: torch.nn.Module) -> torch.Tensor:
        """Dequantize ``layer.weight`` into an fp32 tensor shaped [N, K].

        Scales are applied in fp32 so the caller's cast to the activation
        dtype is the only rounding step.

        Raises:
            RuntimeError: if the weight scale cannot be broadcast to the weight.
        """
        weight = layer.weight.to(torch.float32)
        weight_scale = layer.weight_scale.to(torch.float32)

        if self.block_quant:
            # Block-wise: the weight is kept as [N, K] after loading.
            assert self.weight_block_size is not None
            block_n, block_k = self.weight_block_size  # Note: order is [N, K]
            n, k = weight.shape
            scale_nk = self._block_scale_as_nk(weight_scale, n, k, block_n, block_k)
            scale_expanded = scale_nk.repeat_interleave(
                block_n, dim=0
            ).repeat_interleave(block_k, dim=1)
            # Trim to exact weight size (edge blocks may be partial).
            return weight * scale_expanded[:n, :k]

        # Per-tensor path: the weight was transposed to [K, N] after loading,
        # so output channels run along dim 1.
        if weight_scale.numel() == 1:
            return (weight * weight_scale).t()

        n = weight.shape[1]
        flat_scale = weight_scale.reshape(-1)
        widths = layer.logical_widths
        if flat_scale.numel() == n:
            # One scale per output channel.
            column_scale = flat_scale
        elif flat_scale.numel() == len(widths):
            # One scale per fused shard (e.g. QKV): spread over each shard.
            column_scale = torch.cat(
                [s.expand(width) for s, width in zip(flat_scale, widths)]
            )
        else:
            raise RuntimeError(
                "Cannot broadcast FP8 weight scale of shape "
                f"{tuple(weight_scale.shape)} to weight of shape "
                f"{tuple(layer.weight.shape)}"
            )
        return (weight * column_scale.unsqueeze(0)).t()

    @staticmethod
    def _block_scale_as_nk(
        weight_scale: torch.Tensor,
        n: int,
        k: int,
        block_n: int,
        block_k: int,
    ) -> torch.Tensor:
        """Return the block scale oriented as [blocks_n, blocks_k] (or larger).

        Raises:
            RuntimeError: if the scale is not 2-D or does not cover the weight.
        """
        # determine expected number of blocks along N and K
        num_blocks_n = (n + block_n - 1) // block_n
        num_blocks_k = (k + block_k - 1) // block_k

        # scale layout may be [num_blocks_n, num_blocks_k]
        # or [num_blocks_k, num_blocks_n] depending on backend
        if weight_scale.dim() != 2:
            raise RuntimeError(
                f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
            )

        scale_rows, scale_cols = weight_scale.shape
        if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
            if num_blocks_n == num_blocks_k:
                # Ambiguous square case: keep as-is. apply() runs on every
                # forward, so warn only once per shape instead of per call.
                logger.warning_once(
                    "Batch-invariant FP8: square block-scale %dx%d; "
                    "skipping transpose to avoid misorientation.",
                    scale_rows,
                    scale_cols,
                )
            else:
                # clear KN -> transpose to NK
                weight_scale = weight_scale.t()

        if weight_scale.shape[0] < num_blocks_n or weight_scale.shape[1] < num_blocks_k:
            raise RuntimeError(
                f"FP8 block scale of shape {tuple(weight_scale.shape)} does not "
                f"cover a [{n}, {k}] weight with block size [{block_n}, {block_k}]"
            )
        return weight_scale
682
683


684
685
686
687
688
689
690
691
692
693
694
695
696
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None

        self.fused_experts: mk.FusedMoEModularKernel | None = None  # type: ignore

        self.fp8_backend = get_fp8_moe_backend(self.block_quant)

        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

        # The real value is determined in process_weights_after_loading();
        # initialize it here so methods that read it (e.g.
        # maybe_make_prepare_finalize) never see a missing attribute.
        self.rocm_aiter_moe_enabled = False

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (not yet loaded) expert parameters on ``layer``.

        Registers ``w13_weight`` with shape
        ``[num_experts, 2 * intermediate_size_per_partition, hidden_size]``
        (gate and up projections stacked) and ``w2_weight`` with shape
        ``[num_experts, hidden_size, intermediate_size_per_partition]``.
        Their dtype is FP8 when the checkpoint is FP8-serialized, otherwise
        ``params_dtype``.

        Weight scales:
          * per-tensor: ``w13_weight_scale`` ``[num_experts, 2]`` (one per
            w1/w3 shard, merged after loading) and ``w2_weight_scale``
            ``[num_experts]``.
          * block-wise: ``w13_weight_scale_inv`` / ``w2_weight_scale_inv``
            with one entry per (block_n, block_k) tile, rounded up.

        With a static activation scheme, ``w13_input_scale`` and
        ``w2_input_scale`` (``[num_experts]``) are registered as well;
        otherwise both attributes are set to ``None``.

        Raises:
            ValueError: if block-wise quantization is used and the partition
                size is not divisible by the block size, or if a static
                activation scheme is requested for a non-FP8 checkpoint.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded expert weights into the layout the chosen kernel uses.

        Three mutually exclusive paths:
          * block-wise quantization: optionally normalize to e4m3fnuz, swap
            w1/w3 halves for FlashInfer, shuffle for ROCm AITER, and
            pre-align DeepGEMM scales;
          * non-FP8 checkpoint: quantize each expert's weights to FP8 in
            place with a single per-expert scale;
          * FP8 checkpoint (per-tensor): collapse activation scales to a
            single max value, re-quantize both w13 shards to a shared
            per-expert max scale, then apply AITER / FlashInfer layouts.

        Afterwards Marlin repacking and DeepGEMM UE8M0 re-quantization are
        applied when enabled.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled,
            shuffle_weights,
        )

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w13_weight_scale_inv
                    )
                if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                        layer.w2_weight_scale_inv
                    )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Use the layer's w2 parameter: a local ``w2_weight`` is
                    # only bound on the fnuz (ROCm) path above, which never
                    # coincides with FlashInfer, and rotating a detached local
                    # would not update the layer anyway.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv
                )
            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv
                )

    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the modular-kernel path.

        Returns ``None`` for backends that do not use it here (ROCm AITER,
        Marlin, FlashInfer TRT-LLM). FlashInfer CUTLASS builds its own;
        every other backend defers to the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching ``prepare_finalize``.

        Returns ``BatchedTritonOrDeepGemmExperts`` for batched activation
        formats, the FlashInfer CUTLASS experts when that backend is active,
        and ``TritonOrDeepGemmExperts`` otherwise.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config from ``layer``'s scales.

        Returns ``None`` for Marlin. Block-quantized layers use the
        ``*_weight_scale_inv`` tensors; per-tensor layers use
        ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass on ``x``.

        FlashInfer TRT-LLM (without a modular kernel) does its own routing
        and returns immediately. Every other backend first selects experts
        via ``FusedMoE.select_experts`` and then dispatches to ROCm AITER,
        Marlin, the modular kernel, FlashInfer CUTLASS, or the default
        ``fused_experts``, in that order.

        Returns:
            The expert output. When the layer has zero experts configured,
            returns ``(result, zero_expert_result)`` instead.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if (
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
            and self.fused_experts is None
        ):
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )
            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                assert (
                    renormalize and use_grouped_topk and custom_routing_function is None
                )
                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert not renormalize and custom_routing_function is not None
                # Return directly: falling through would discard this result
                # and recompute it with the generic kernel on weights that
                # were swapped/rotated for FlashInfer.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )

        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        select_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
            global_num_experts=global_num_experts,
            zero_expert_num=zero_expert_num,
            zero_expert_type=zero_expert_type,
            num_fused_shared_experts=layer.num_fused_shared_experts,
        )

        #
        # Note: the order of checks is important since self.fused_experts
        # can override fused_experts or cutlass but not rocm or marlin.
        #
        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            assert self.fused_experts is None
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
            assert self.fused_experts is None
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace,
            )
        elif self.fused_experts:
            result = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert not renormalize and custom_routing_function is not None
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
            assert scoring_func == "sigmoid", (
                f"Expected 'sigmoid' scoring func but got {scoring_func}"
            )

            result = flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )
        if zero_expert_num != 0 and zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1394
1395


1396
1397
1398
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1399
1400
1401
    """

    def __init__(self, quant_config: Fp8Config):
1402
        super().__init__(quant_config)