fp8.py 57.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from enum import Enum
6
from functools import partial
7
from typing import TYPE_CHECKING, Any, Optional
8
9
10
11
12

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

13
import vllm.envs as envs
14
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
15
from vllm import _custom_ops as ops
16
from vllm._aiter_ops import rocm_aiter_ops
17
from vllm.attention.layer import Attention
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.batch_invariant import (
21
    vllm_is_batch_invariant,
22
)
bnellnm's avatar
bnellnm committed
23
from vllm.model_executor.layers.fused_moe import (
24
25
26
27
28
29
30
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
31
from vllm.model_executor.layers.fused_moe.config import (
32
    FusedMoEParallelConfig,
33
    FusedMoEQuantConfig,
34
    RoutingMethodType,
35
36
    fp8_w8a8_moe_quant_config,
)
37
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
38
39
40
41
42
43
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
from vllm.model_executor.layers.quantization.base_config import (
46
47
48
    QuantizationConfig,
    QuantizeMethodBase,
)
49
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
50
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
51
52
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
53
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
54
55
56
57
58
59
60
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
61
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
62
63
64
65
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
66
    deepgemm_post_process_fp8_weight_block,
67
68
69
70
71
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
72
73
74
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
75
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
76
77
78
79
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
80
from vllm.model_executor.layers.quantization.utils.quant_utils import (
81
82
83
    GroupShape,
    is_layer_skipped,
)
84
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
85
86
87
88
89
90
91
92
93
94
95
96
97
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
98
from vllm.model_executor.utils import set_weight_attrs
99
from vllm.platforms import current_platform
100
from vllm.scalar_type import scalar_types
101
102
103
104
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
105
from vllm.utils.flashinfer import has_flashinfer_moe
106
from vllm.utils.import_utils import has_deep_gemm
107

108
109
110
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

111
112
113
114
logger = init_logger(__name__)

# Activation schemes accepted by Fp8Config. With "static", FP8-serialized
# checkpoints register a loadable `input_scale`; with "dynamic" none is registered.
ACTIVATION_SCHEMES = ["static", "dynamic"]

115

116
117
118
119
120
121
122
123
124
125
class Fp8MoeBackend(Enum):
    """FP8 MoE kernel backends that ``get_fp8_moe_backend`` chooses between.

    Member values are explicit integers.
    """

    # NOTE(review): NONE is never returned by get_fp8_moe_backend; confirm
    # whether other code relies on it before removing.
    NONE = 0
    # FlashInfer MoE when get_flashinfer_moe_backend() reports TENSORRT_LLM.
    FLASHINFER_TRTLLM = 1
    # FlashInfer MoE with its CUTLASS ("throughput") backend on SM90/SM100.
    FLASHINFER_CUTLASS = 2
    # DeepGEMM; only selected for block-quantized weights.
    DEEPGEMM = 3
    # CUTLASS block-scaled grouped GEMM; CUDA SM100 with block-quantized weights.
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Marlin weight-only path for GPUs without native FP8 (never on ROCm).
    MARLIN = 5
    # Triton; used whenever LoRA is enabled and as the final fallback.
    TRITON = 6


126
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend:
    """
    Select the primary FP8 MoE backend.

    Candidates are checked in priority order; the first match wins:

    1. Triton whenever LoRA support is requested.
    2. FlashInfer (TRTLLM or CUTLASS, as reported by
       ``get_flashinfer_moe_backend()``) on CUDA SM90/SM100 when
       ``VLLM_USE_FLASHINFER_MOE_FP8`` is set and FlashInfer MoE is available.
    3. Marlin (weight-only) on non-ROCm platforms that lack device
       capability 8.9 or set ``VLLM_TEST_FORCE_FP8_MARLIN``.
    4. DeepGEMM for block-quantized weights when requested and supported.
    5. CUTLASS block-scaled grouped GEMM on CUDA SM100 for block-quantized
       weights.
    6. Triton otherwise.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the expert weights are block-quantized.
        moe_parallel_config: Parallel config; its ``tp_size`` decides the
            default for DeepGEMM MoE when ``VLLM_MOE_USE_DEEP_GEMM`` is unset.
        with_lora_support: Whether LoRA is enabled for the MoE layer.

    Returns:
        The selected ``Fp8MoeBackend``.

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) backend would be
            used on SM100 with block-quantized weights.
    """
    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # --- FlashInfer: preferred on SM90 and SM100 ------------------------------
    flashinfer_usable = (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_usable:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # The throughput (CUTLASS) backend rejects block quantization on SM100.
        if block_quant and current_platform.is_device_capability(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # --- Marlin: weight-only path for GPUs without native FP8 (not ROCm) -----
    marlin_wanted = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if marlin_wanted and not current_platform.is_rocm():
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # --- DeepGEMM: block-quantized weights only -------------------------------
    # An explicit VLLM_MOE_USE_DEEP_GEMM is honored as-is; when it is left
    # unset, DeepGEMM MoE is switched off for TP size >= 8.
    deep_gemm_allowed = envs.VLLM_MOE_USE_DEEP_GEMM
    tp_default_off = (
        not envs.is_set("VLLM_MOE_USE_DEEP_GEMM")
        and moe_parallel_config.tp_size >= 8
    )
    if tp_default_off:
        deep_gemm_allowed = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    if envs.VLLM_USE_DEEP_GEMM and deep_gemm_allowed and block_quant:
        if has_deep_gemm():
            if is_deep_gemm_supported():
                logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
                return Fp8MoeBackend.DEEPGEMM
        else:
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )

    # --- CUTLASS block-scaled grouped GEMM: CUDA SM100 + block quant ----------
    if block_quant and (
        current_platform.is_cuda() and current_platform.is_device_capability(100)
    ):
        logger.info_once(
            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE", scope="local"
        )
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # --- Fallback --------------------------------------------------------------
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


210
class Fp8Config(QuantizationConfig):
    """Config class for FP8 quantization.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights (otherwise higher-precision weights are quantized
            after loading).
        activation_scheme: One of ``ACTIVATION_SCHEMES``.
        ignored_layers: Layers matched against this list by
            ``is_layer_skipped`` are left unquantized.
        weight_block_size: Optional 2-element block shape for block-wise
            weight quantization. Requires an FP8-serialized checkpoint and
            the ``"dynamic"`` activation scheme.

    Raises:
        ValueError: On an unknown activation scheme or an unsupported
            block-wise configuration.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")

        if weight_block_size is not None:
            # Block-wise weights currently need an FP8 checkpoint, a 2-D
            # block shape and dynamic activation scales (checked in order).
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Name under which this quantization method is registered."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes accepted by FP8 layers."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are read for FP8."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        ``quant_method`` and ``activation_scheme`` are read with
        ``get_from_keys``. An empty or missing ``ignored_layers`` falls back to
        ``modules_to_not_convert``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        skip_list = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        skip_list = skip_list or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized="fp8" in quant_method,
            activation_scheme=activation_scheme,
            ignored_layers=skip_list,
            weight_block_size=block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer`` on XPU, or None."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh config built from this one's fields.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if skipped:
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``, or None.

        Linear and MoE layers matched by ``ignored_layers`` get the
        unquantized method. XPU is delegated to ``get_xpu_quant_method``.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        # Evaluated lazily: only linear and MoE layers consult the skip list.
        def _skipped():
            return is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )

        if isinstance(layer, LinearBase):
            if _skipped():
                return UnquantizedLinearMethod()
            linear_method = Fp8LinearMethod(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method
        if isinstance(layer, FusedMoE):
            if _skipped():
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_method = Fp8MoEMethod(self, layer)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a compressed-tensors q/k/v/prob scale param name to its vLLM name.

        :param name: param name
        :return: the equivalent vLLM param name, or None if ``name`` matches
            none of the known patterns
        """
        if name.endswith(".output_scale"):
            # Checked in order k, v, q; the first projection found wins.
            for proj, old_suffix, new_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if proj in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None

363
364
365

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Handles three kinds of checkpoints:

    * FP8-serialized with per-tensor weight scales (one per logical shard of
      a fused module). Unless Marlin is used, the shards are requantized into
      a single per-tensor scale because torch._scaled_mm needs per-tensor
      scales. Activation scales may be dynamic or static.
    * FP8-serialized with block-wise weight scales (``weight_block_size``).
      These require the dynamic activation scheme.
    * Non-FP8 (e.g. FP16/BF16) checkpoints. Weights are quantized with
      ``ops.scaled_fp8_quant`` after loading, and no ``input_scale`` is
      registered, so activations are expected to use dynamic scaling.

    On platforms without device capability 8.9, or when
    ``VLLM_TEST_FORCE_FP8_MARLIN`` is set, the Marlin weight-only kernel is
    used instead (never on ROCm or in batch-invariant mode).

    Note: torch._scaled_mm only supports the float8_e4m3fn data type
    (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        # Fp8Config.get_quant_method overwrites this per layer prefix.
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode takes its own GEMM paths in apply().
        if vllm_is_batch_invariant():
            self.use_marlin = False

        # NOTE(review): "enaled" is kept as spelled; it presumably matches the
        # method name on rocm_aiter_ops - confirm before renaming.
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight, scale and input-scale parameters on ``layer``.

        FP8-serialized checkpoints get an FP8 ``weight``, plus either a
        per-tensor ``weight_scale`` or a block-wise ``weight_scale_inv``. They
        also get an ``input_scale``, which is None unless the activation scheme
        is static. Other checkpoints get a ``params_dtype`` weight whose
        loader calls ``process_weights_after_loading`` once every element of
        the weight has been written.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:
            # Imported lazily: only the online-quantization path needs it.
            from torch.utils._python_dispatch import TorchDispatchMode

            class _CopyNumelCounter(TorchDispatchMode):
                """Count the elements written by ``aten.copy_`` while active."""

                def __init__(self) -> None:
                    super().__init__()
                    self.copied_numel = 0

                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    out = func(*args, **(kwargs or {}))
                    if func == torch.ops.aten.copy_.default:
                        # args[0] is the destination (a view into `param`).
                        self.copied_numel += args[0].numel()
                    return out

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # Count what the loader actually copies into `param`, not
                # `loaded_weight.numel()`. The loader may narrow `loaded_weight`
                # before copying (presumably to this rank's tensor-parallel
                # shard). Counting the full tensor can then reach the target
                # early, e.g. for a merged gate/up weight, and quantize the
                # layer while other shards are still uninitialized.
                counter = _CopyNumelCounter()
                with counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call
                    # from doing anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the layout ``apply`` expects.

        * Block-quantized: weight and scale come from
          ``process_fp8_weight_block_strategy``; ``weight_scale_inv`` is
          replaced by ``weight_scale``.
        * Non-FP8 checkpoint: the weight is quantized with
          ``ops.scaled_fp8_quant`` and transposed.
        * FP8 per-tensor checkpoint: unless Marlin is used, the per-shard
          scales are requantized into a single weight scale. The weight is
          transposed.

        With Marlin, the layer is then prepared by
        ``prepare_fp8_layer_for_marlin`` and ``input_scale`` is removed.
        Returns immediately if the patched weight loader already ran this.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )
            # Delete the weight_scale_inv parameter to avoid confusion
            # with the weight_scale parameter
            del layer.weight_scale_inv

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (
            Parameter(input_scale, requires_grad=False)
            if input_scale is not None
            else None
        )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_weight(layer: torch.nn.Module, dtype: torch.dtype) -> torch.Tensor:
        """Dequantize the non-block ``layer.weight`` (laid out as [K, N]) to ``dtype``.

        A single-element ``weight_scale`` scales the whole tensor. A scale with
        one entry per output column (N elements, e.g. shape [N] or [N, 1]) is
        broadcast along dim 1, the output dimension of the [K, N] weight. Any
        other scale is multiplied directly and must broadcast against [K, N].
        """
        weight = layer.weight.to(dtype)
        scale = layer.weight_scale.to(dtype)
        if scale.numel() != 1 and scale.numel() == weight.shape[1]:
            scale = scale.reshape(1, -1)
        return weight * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T + bias`` using the kernel chosen at init time."""
        # Batch-invariant mode: block-quantized weights keep the block FP8
        # GEMM. Per-tensor/channel weights are dequantized to the activation
        # dtype (so fp16 as well as bf16 activations and bias match the weight
        # dtype) and run through a plain linear.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            weight_dq = self._dequantize_weight(layer, x.dtype)
            return torch.nn.functional.linear(x, weight_dq.t(), bias)

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
677
678


679
680
681
682
683
684
685
686
687
688
689
690
691
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

692
693
694
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Pick the FP8 MoE backend for *layer* and derive per-backend flags.

        Args:
            quant_config: The FP8 quantization config. Its ``weight_block_size``
                decides between block-wise and per-tensor weight quantization.
            layer: The MoE layer. ``moe_config`` and ``moe_parallel_config`` are
                read from it, and it is kept on ``self.layer``.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        # Forwarded as ``input_dtype`` to the Marlin prepare/apply helpers.
        self.marlin_input_dtype = None
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant:
                # The CUTLASS block-scale path only handles 128x128 blocks.
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            self.flashinfer_moe_fn = partial(
                flashinfer_cutlass_moe_fp8,
                moe=self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )
724

725
726
727
728
729
730
731
732
733
    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights, weight scales and input scales on *layer*.

        Registers these weights:

        * ``w13_weight`` of shape
          ``(num_experts, 2 * intermediate_size_per_partition, hidden_size)``.
          This is the gate and up projections stacked.
        * ``w2_weight`` of shape
          ``(num_experts, hidden_size, intermediate_size_per_partition)``.

        Both use ``torch.float8_e4m3fn`` when the checkpoint is FP8-serialized
        and *params_dtype* otherwise.

        Weight scales depend on the quantization mode:

        * Per-tensor: ``w13_weight_scale`` with shape ``(num_experts, 2)`` and
          ``w2_weight_scale`` with shape ``(num_experts,)``.
        * Block-wise: ``w13_weight_scale_inv`` and ``w2_weight_scale_inv``,
          with one scale per ``block_n x block_k`` tile.

        The static activation scheme also registers per-expert
        ``w13_input_scale`` and ``w2_input_scale``. Otherwise both attributes
        are set to ``None``.

        For checkpoints that are not FP8-serialized, the ``weight_loader`` in
        *extra_weight_attrs* is wrapped. The wrapper runs
        ``process_weights_after_loading`` as soon as every element of
        ``w13_weight`` and ``w2_weight`` has been loaded.

        Raises:
            ValueError: If a block-quantized size is not divisible by the block
                shape, or if a static activation scheme is configured for a
                checkpoint that is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # if we are doing online quantization, patch the weight
        # loaded to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded
        if not self.quant_config.is_checkpoint_fp8_serialized:
            weight_loader = extra_weight_attrs["weight_loader"]
            # Work on a copy so the patched loader is only installed on the
            # parameters registered below, not on any other holder of the dict.
            extra_weight_attrs = dict(extra_weight_attrs)

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # load the current weight chunk
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

                # add a counter to track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0
                layer._loaded_numel += loaded_weight.numel()

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call
                    # from doing anything
                    layer._already_called_process_weights_after_loading = True

                return res

            extra_weight_attrs["weight_loader"] = patched_weight_loader

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding partial tiles up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

903
    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded weights into the layout the selected kernels expect.

        Exactly one of three paths runs:

        * Block-quantized weights:
          - Normalize to fnuz on fnuz platforms.
          - Swap the w1/w3 halves for FlashInfer.
          - Shuffle for ROCm AITER.
          - Transpose and align the scales for DeepGEMM.
        * Checkpoint that is not FP8-serialized (online quantization):
          quantize each local expert's ``w13``/``w2`` to FP8, with one scale
          per expert and weight.
        * FP8-serialized per-tensor checkpoint:
          - Collapse static input scales to their max.
          - Normalize to fnuz on fnuz platforms.
          - Requantize ``w13`` so both shards share the larger scale.
          - Shuffle for ROCm AITER.
          - Prepare the FlashInfer layouts.

        After that, Marlin (if selected) repacks the layer and the input scales
        are deleted.

        If the streaming loader installed by ``create_weights`` has already run
        this method, it returns immediately.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # Input scales are not used on the (dynamic) block-quant path.
                w13_weight, w13_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight,
                    layer.w13_weight_scale_inv,
                    layer.w13_input_scale,
                )
                w2_weight, w2_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(
                w13_weight_scale_inv, requires_grad=False
            )
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(
                w2_weight_scale_inv, requires_grad=False
            )
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.local_num_experts,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
            # NOTE(review): unlike the fp8-checkpoint branch below, this branch
            # applies no FlashInfer w13->w31 swap / scaling-factor registration;
            # confirm FlashInfer backends are never selected for online quant.

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # Shard 0 and shard 1 are the two stacked halves of w13.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Pass the registered w2 parameter: a local ``w2_weight`` is
                    # only bound on the fnuz path above, so it cannot be used
                    # here. The helper's return value is unused, so it
                    # presumably updates both tensors in place.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale
1121

1122
1123
1124
1125
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for this method, or ``None``.

        The result depends on the backend:

        * ROCm AITER, Marlin and the FlashInfer TRT-LLM backend return ``None``.
        * The FlashInfer CUTLASS backend builds its own object, with the
          block-scale flag wired through.
        * Every other backend defers to the base class.

        Args:
            routing_tables: Forwarded unchanged to the base-class implementation.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)
1147

bnellnm's avatar
bnellnm committed
1148
1149
1150
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts (GEMM) implementation that pairs with *prepare_finalize*.

        Args:
            prepare_finalize: The prepare/finalize object in use. Its
                ``activation_format`` decides between batched and non-batched
                experts.
            layer: The MoE layer. It is not read here.

        Returns:
            * ``BatchedDeepGemmExperts`` or ``BatchedTritonExperts`` for the
              batched activation format.
            * ``TritonExperts`` when LoRA is enabled.
            * The FlashInfer CUTLASS experts for that backend.
            * ``TritonOrDeepGemmExperts`` otherwise.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

1212
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config consumed by the fused-MoE kernels.

        Returns ``None`` when Marlin is used.

        The weight scales come from ``w13_weight_scale_inv`` and
        ``w2_weight_scale_inv`` for block-quantized layers, and from
        ``w13_weight_scale`` and ``w2_weight_scale`` otherwise. Activation
        scales are the layer's ``w13_input_scale`` and ``w2_input_scale``,
        which may be ``None``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

1232
1233
1234
1235
1236
1237
1238
1239
    @property
    def supports_eplb(self) -> bool:
        """Report EPLB compatibility; always ``True`` for this method.

        ``apply`` accepts the EPLB bookkeeping arguments (``expert_load_view``,
        ``logical_to_physical_map``, ``logical_replica_count``).
        """
        eplb_ok: bool = True
        return eplb_ok

    @property
    def allow_inplace(self) -> bool:
        """Return ``True`` unconditionally, for every FP8 backend.

        Presumably callers use this to decide whether the fused-MoE output may
        overwrite its input. Confirm against ``FusedMoEMethodBase``.
        """
        inplace_ok: bool = True
        return inplace_ok

1240
1241
    def apply(
        self,
1242
        layer: FusedMoE,
1243
1244
1245
1246
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
1247
        use_grouped_topk: bool = False,
1248
1249
        topk_group: int | None = None,
        num_expert_group: int | None = None,
1250
        global_num_experts: int = -1,
1251
1252
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
Simon Mo's avatar
Simon Mo committed
1253
        scoring_func: str = "softmax",
1254
        routed_scaling_factor: float = 1.0,
1255
        e_score_correction_bias: torch.Tensor | None = None,
1256
        apply_router_weight_on_input: bool = False,
Michael Goin's avatar
Michael Goin committed
1257
        activation: str = "silu",
1258
        enable_eplb: bool = False,
1259
1260
1261
1262
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
1263
1264
1265
1266
1267
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)
1268

1269
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
1270
1271
1272
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
1273

1274
            if self.block_quant:
1275
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
1276
1277
1278
1279
1280
1281

                e_score_correction_bias = (
                    e_score_correction_bias.to(x.dtype)
                    if e_score_correction_bias is not None
                    else None
                )
1282
                routing_method_type = layer.routing_method_type
1283
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
1284
1285
1286
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
1300
                    block_shape=self.weight_block_size,
1301
                    routing_method_type=routing_method_type,
1302
                    routed_scaling=routed_scaling_factor,
1303
1304
                )
            else:
1305
                assert not renormalize and custom_routing_function is not None
XuruiYang's avatar
XuruiYang committed
1306
                result = apply_flashinfer_per_tensor_scale_fp8(
1307
1308
1309
1310
1311
1312
1313
1314
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
1315
1316
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
1317

1318
        select_result = layer.select_experts(
1319
1320
1321
1322
            hidden_states=x,
            router_logits=router_logits,
        )

XuruiYang's avatar
XuruiYang committed
1323
1324
        topk_weights, topk_ids, zero_expert_result = select_result

1325
1326
        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
1327
1328
1329
                rocm_aiter_fused_experts,
            )

XuruiYang's avatar
XuruiYang committed
1330
            result = rocm_aiter_fused_experts(
1331
1332
1333
                x,
                layer.w13_weight,
                layer.w2_weight,
1334
1335
1336
1337
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
1338
                expert_map=expert_map,
1339
1340
                quant_config=self.moe_quant_config,
            )
1341
        elif self.use_marlin:
1342
            assert activation == "silu", f"{activation} not supported for Marlin MoE."
1343
            result = fused_marlin_moe(
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
1357
                expert_map=expert_map,
1358
                input_dtype=self.marlin_input_dtype,
1359
1360
                workspace=layer.workspace,
            )
1361
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
1362
1363
1364
            assert activation == "silu", (
                f"Expected 'silu' activation but got {activation}"
            )
1365
1366
1367
1368
1369
1370
1371
1372
            if not self.block_quant:
                assert not renormalize and custom_routing_function is not None
                assert scoring_func == "sigmoid", (
                    f"Expected 'sigmoid' scoring func but got {scoring_func}"
                )
            # Delegate to CUTLASS FlashInfer path; function already bound with
            # use_deepseek_fp8_block_scale for block-quant when applicable
            result = self.flashinfer_moe_fn(
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
1383
        else:
1384
            from vllm.model_executor.layers.fused_moe import fused_experts
1385

XuruiYang's avatar
XuruiYang committed
1386
            result = fused_experts(
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
1397
1398
1399
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
1400
1401
1402
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )
1403
1404

        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
1405
            assert not isinstance(result, tuple), (
XuruiYang's avatar
XuruiYang committed
1406
                "Shared + zero experts are mutually exclusive not yet supported"
1407
            )
XuruiYang's avatar
XuruiYang committed
1408
1409
1410
            return result, zero_expert_result
        else:
            return result
1411
1412


1413
1414
1415
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1416
1417
1418
    """

    def __init__(self, quant_config: Fp8Config):
1419
        super().__init__(quant_config)