fp8.py 59.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Optional
7
8
9
10

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
11
from torch.utils._python_dispatch import TorchDispatchMode
12

13
import vllm.envs as envs
14
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
15
from vllm import _custom_ops as ops
16
from vllm._aiter_ops import rocm_aiter_ops
17
from vllm.attention.layer import Attention
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.batch_invariant import (
21
    vllm_is_batch_invariant,
22
)
bnellnm's avatar
bnellnm committed
23
from vllm.model_executor.layers.fused_moe import (
24
25
26
27
28
29
30
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
31
from vllm.model_executor.layers.fused_moe.config import (
32
    FusedMoEParallelConfig,
33
    FusedMoEQuantConfig,
34
    RoutingMethodType,
35
36
    fp8_w8a8_moe_quant_config,
)
37
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
38
39
40
41
42
43
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
from vllm.model_executor.layers.quantization.base_config import (
46
47
48
    QuantizationConfig,
    QuantizeMethodBase,
)
49
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
50
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
51
52
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
53
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
54
55
56
57
58
59
60
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
61
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
62
63
64
65
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
66
    deepgemm_post_process_fp8_weight_block,
67
68
69
70
71
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
72
73
74
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
75
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
76
77
78
79
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
80
from vllm.model_executor.layers.quantization.utils.quant_utils import (
81
82
83
    GroupShape,
    is_layer_skipped,
)
84
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
85
86
87
88
89
90
91
92
93
94
95
96
97
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
98
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
99
from vllm.platforms import current_platform
100
from vllm.scalar_type import scalar_types
101
102
103
104
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
105
from vllm.utils.flashinfer import has_flashinfer_moe
106
from vllm.utils.import_utils import has_deep_gemm
107

108
109
110
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

111
112
113
114
logger = init_logger(__name__)

# Activation quantization schemes that Fp8Config accepts.
ACTIVATION_SCHEMES = ["static", "dynamic"]

115

116
117
118
119
120
121
122
123
124
125
class Fp8MoeBackend(Enum):
    """Kernel backends that the FP8 MoE backend selector can return.

    Values are explicit, contiguous integers starting at 0.
    """

    NONE = 0
    # FlashInfer-provided kernels.
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    # Block-quantized kernels.
    DEEPGEMM = 3
    CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4
    # Weight-only path for GPUs without native FP8.
    MARLIN = 5
    # Default fallback.
    TRITON = 6


126
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend:
    """Choose the primary FP8 MoE backend for this platform and config.

    Candidates are considered in priority order: Triton when LoRA is on,
    FlashInfer (TRTLLM or CUTLASS), Marlin (weight-only), DeepGEMM, CUTLASS
    block-scaled grouped GEMM, and finally Triton as the default.

    Note: shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the weights are block-quantized.
        moe_parallel_config: Parallel config; its ``tp_size`` decides whether
            DeepGEMM is disabled by default.
        with_lora_support: Whether LoRA is enabled for the MoE layer.

    Returns:
        The selected ``Fp8MoeBackend``.

    Raises:
        ValueError: If the FlashInfer CUTLASS backend would be used with
            block-quantized weights on an SM100-family device.
    """
    # LoRA always goes through Triton.
    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # FlashInfer: CUDA on SM100-family or SM90, opted in and installed.
    flashinfer_eligible = (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_eligible:
        fi_backend = get_flashinfer_moe_backend()
        if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # Any other FlashInfer backend maps to CUTLASS, which rejects block
        # quantization on SM100-family devices.
        if block_quant and current_platform.is_device_capability_family(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only Marlin for GPUs without native FP8 (or when forced by the
    # test env var); never used on ROCm.
    wants_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if wants_marlin and not current_platform.is_rocm():
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM with block-quantized weights. An explicit
    # VLLM_MOE_USE_DEEP_GEMM is respected as-is; when it is left at its
    # default, DeepGEMM is turned off for TP size >= 8.
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    deep_gemm_left_default = not envs.is_set("VLLM_MOE_USE_DEEP_GEMM")
    if deep_gemm_left_default and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS block-scaled grouped GEMM on SM100-family with block quant.
    use_cutlass_block_scaled = (
        current_platform.is_cuda()
        and current_platform.is_device_capability_family(100)
        and block_quant
    )
    if use_cutlass_block_scaled:
        logger.info_once(
            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE", scope="local"
        )
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # Nothing more specific applied: fall back to Triton.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


210
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Stores whether the checkpoint already holds FP8 weights, the activation
    quantization scheme, the layers to leave unquantized, and an optional
    weight block size. It also maps each layer to its quantize method.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: Whether checkpoint weights are
                already stored in FP8.
            activation_scheme: One of ``ACTIVATION_SCHEMES``
                (``"static"`` or ``"dynamic"``).
            ignored_layers: Layer prefixes that stay unquantized.
            weight_block_size: Optional two-element block size for
                block-wise weight quantization.

        Raises:
            ValueError: On an unknown activation scheme, or when a block size
                is given with a non-FP8 checkpoint, a non-2D block size, or a
                non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        ``quant_method`` and ``activation_scheme`` are required keys. When
        ``ignored_layers`` is absent or empty, ``modules_to_not_convert`` is
        used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Return whether the layer at ``prefix`` is in ``ignored_layers``."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU quantize method for ``layer``, or ``None``.

        Layers listed in ``ignored_layers`` get the unquantized linear/MoE
        method, in the same way as ``get_quant_method``.
        """
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            # Honor ignored_layers for MoE too; otherwise a skipped MoE layer
            # would be handed the FP8 MoE method.
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or ``None``.

        XPU is delegated to ``get_xpu_quant_method``. Linear and MoE layers
        listed in ``ignored_layers`` get their unquantized methods. MoE
        layers use ``Fp8MoEMethod`` for FP8 checkpoints and
        ``Fp8OnlineMoEMethod`` otherwise.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

366

367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
class CopyNumelCounter(TorchDispatchMode):
    """Dispatch mode that totals the elements written by ``aten.copy_``.

    While the mode is active, every ``aten::copy_.default`` call adds the
    element count of its destination (the first positional argument) to
    ``copied_numel``. All other ops pass through untouched. This helps track
    weight loading, where the loader may reshape or slice tensors (e.g. with
    ``narrow``) before it copies into the parameter.
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs if kwargs is not None else {}))
        if func == torch.ops.aten.copy_.default:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


387
388
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, scale parameters.

        For FP8-serialized checkpoints this registers an FP8 ``weight``, a
        ``weight_scale`` (per-tensor) or ``weight_scale_inv`` (block-wise),
        and ``input_scale`` (a parameter for static activations, else
        ``None``). For unquantized checkpoints it registers a
        ``params_dtype`` weight whose loader calls
        ``process_weights_after_loading`` once every element has been copied.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the chosen kernel expects.

        This is a no-op if the online-quantization loader already ran it for
        this layer. Non-block weights are quantized (unquantized checkpoints)
        or requantized per-tensor (FP8 checkpoints, non-Marlin), then stored
        transposed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_weight(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        logical_widths: list[int] | None,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Dequantize a transposed ``[K, N]`` FP8 weight to ``dtype``.

        Handled scale layouts:
        - a single scale (per-tensor);
        - one scale per logical shard of a fused module (e.g. QKV), expanded
          over each shard's output columns using ``logical_widths``;
        - one scale per output channel (``N`` scales).
        Any other layout falls back to plain broadcasting.
        """
        weight_dq = weight.to(dtype)
        scale = weight_scale.to(dtype)
        if scale.numel() == 1:
            return weight_dq * scale

        # Output channels are the columns (last dim) of the [K, N] weight.
        num_out = weight_dq.shape[-1]
        flat_scale = scale.reshape(-1)
        if (
            logical_widths is not None
            and len(logical_widths) > 1
            and flat_scale.numel() == len(logical_widths)
            and sum(logical_widths) == num_out
        ):
            # One scale per fused shard: repeat each across its columns.
            repeats = torch.tensor(logical_widths, device=flat_scale.device)
            flat_scale = torch.repeat_interleave(flat_scale, repeats)
        if flat_scale.numel() == num_out:
            return weight_dq * flat_scale.unsqueeze(0)
        return weight_dq * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the FP8 linear output for ``x`` using ``layer``'s weights.

        Dispatch order: batch-invariant path, Marlin, block-wise FP8, then
        ``Fp8LinearOp`` for per-tensor / per-token quantization.
        """
        # In batch-invariant mode, block-quantized weights still use the block
        # FP8 op; other layouts are dequantized and run as a plain GEMM.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # Dequantize in the activation dtype so F.linear sees matching
            # dtypes (a fixed bfloat16 fails for float16 activations).
            weight_dq = self._dequantize_weight(
                layer.weight,
                layer.weight_scale,
                getattr(layer, "logical_widths", None),
                x.dtype,
            )
            # Weight is stored as [K, N]; F.linear expects [N, K].
            return torch.nn.functional.linear(x, weight_dq.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
711
712


713
714
715
716
717
718
719
720
721
722
723
724
725
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Loads FP8 (e4m3fn) expert weights with either per-tensor or block-wise
    weight scales, and static or dynamic activation scales. Block-wise weight
    scales require dynamic activation scaling.

    Unquantized FP16/BF16 checkpoints are not handled by this class
    (``create_weights`` asserts the checkpoint is FP8-serialized); online
    quantization is implemented by the ``Fp8OnlineMoEMethod`` subclass.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            and ``moe_parallel_config`` are used to pick the kernel backend.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Pre-bind the static arguments of the CUTLASS MoE entry point;
            # `apply` supplies the per-call tensors.
            self.flashinfer_moe_fn = partial(
                flashinfer_cutlass_moe_fp8,
                moe=self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )

        self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales on `layer`.

        Weights are allocated as ``float8_e4m3fn``. Per-tensor mode registers
        ``w13_weight_scale`` (two scales per expert, for w1 and w3) and
        ``w2_weight_scale``. Block mode registers ``w13_weight_scale_inv`` and
        ``w2_weight_scale_inv`` with one scale per (block_n, block_k) tile.
        Input scales are registered only for the "static" activation scheme;
        otherwise they are set to None.

        Raises:
            ValueError: if block quantization is used and the partitioned
                intermediate size is not divisible by the block dimensions.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n, block_k) tile, rounding partial tiles up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded checkpoint tensors into the layout the chosen backend uses.

        Depending on platform/backend this converts to fp8-fnuz, swaps the
        w1/w3 halves for FlashInfer, shuffles weights for ROCm AITER,
        re-lays-out block scales for DeepGEMM, collapses the two per-shard
        w13 scales into one per-expert scale (per-tensor mode), and prepares
        the Marlin layout. It is a no-op if the layer is already flagged with
        ``_already_called_process_weights_after_loading``.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            replace_parameter(layer, "w13_weight", w13_weight)
            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
            replace_parameter(layer, "w2_weight", w2_weight)
            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                replace_parameter(layer, "w13_weight", w13_weight)
                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
                if w13_input_scale is not None:
                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
                replace_parameter(layer, "w2_weight", w2_weight)
                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
                if w2_input_scale is not None:
                    replace_parameter(layer, "w2_input_scale", w2_input_scale)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            replace_parameter(layer, "w13_weight_scale", max_w13_scales)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Use the layer's current w2 parameter: the local `w2_weight`
                    # is only bound on fp8-fnuz platforms and would be stale after
                    # replace_parameter / AITER shuffling anyway.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for this backend, if any.

        Returns None for ROCm AITER, Marlin and FlashInfer TRT-LLM. Returns
        a FlashInfer CUTLASS prepare/finalize for the CUTLASS backend, and
        defers to the base class otherwise.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts GEMM implementation matching `prepare_finalize`.

        Batched activation formats use the batched DeepGEMM or Triton experts.
        LoRA uses Triton experts, FlashInfer CUTLASS uses the CUTLASS FP8
        experts, and all other cases use TritonOrDeepGemmExperts.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config, or None when Marlin is used.

        Block mode reads the ``*_weight_scale_inv`` tensors; per-tensor mode
        reads ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass for `x` on the configured backend.

        The FlashInfer TRT-LLM paths receive `router_logits` and return their
        output directly. The other backends first call `layer.select_experts`.
        When the layer has zero experts configured, the result is returned as
        ``(result, zero_expert_result)``.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                # Return directly, as the block-quant branch does. Falling through
                # would discard this result and re-run the generic path on weights
                # that were swapped/rotated for FlashInfer.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        select_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            assert layer.activation == "silu", (
                f"{layer.activation} not supported for Marlin MoE."
            )
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
                workspace=layer.workspace,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )
            if not self.block_quant:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                assert layer.scoring_func == "sigmoid", (
                    f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
                )
            # Delegate to CUTLASS FlashInfer path; function already bound with
            # use_deepseek_fp8_block_scale for block-quant when applicable
            result = self.flashinfer_moe_fn(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            result = fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm
                ),
            )

        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result


class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Supports loading unquantized (e.g. FP16/BF16) model checkpoints and
    quantizing the expert weights to FP8 on the fly, with dynamic activation
    scaling. The per-expert weight scaling factors are computed once all of
    the layer's weights have been loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Online quantization is only supported for checkpoints that are not
        # already FP8, with dynamic activation scales and no block-wise
        # weight scaling.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate high-precision expert weights and placeholder scales.

        The supplied ``weight_loader`` is wrapped so that, as soon as the
        last element of ``w13_weight`` and ``w2_weight`` has been copied in,
        ``process_weights_after_loading`` runs immediately and quantizes the
        layer in a streaming fashion.

        Args:
            layer: The MoE layer to register parameters on.
            num_experts: Number of experts allocated on this layer.
            hidden_size: Hidden dimension of the model.
            intermediate_size_per_partition: Per-partition intermediate size
                of each expert.
            params_dtype: Dtype of the (unquantized) checkpoint weights.
            **extra_weight_attrs: Attributes attached to each weight
                parameter; must contain ``weight_loader``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, so patch the weight loader to
        # call `process_weights_after_loading` in a streaming fashion as soon
        # as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Add a counter to track how many elements we have updated.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Load the current weight chunk, counting the copied elements.
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # If we have loaded all of the elements, call
            # process_weights_after_loading.
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping.
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything.
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a new attrs dict (a real copy, not an alias) so the patched
        # loader does not modify the mapping passed in by the caller.
        weight_attrs = dict(extra_weight_attrs)
        weight_attrs["weight_loader"] = patched_weight_loader

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for w13 and one per expert for w2. They start
        # at 1.0 and are overwritten in `process_weights_after_loading` when
        # each expert's weights are quantized.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Activation scales are dynamic, so no static input scales exist.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision expert weights to FP8.

        Each expert's ``w13`` and ``w2`` weights are quantized independently
        with ``ops.scaled_fp8_quant``. The returned scales are written into
        ``w13_weight_scale`` / ``w2_weight_scale``. The weights are then
        reshuffled for the ROCm AITER or Marlin kernels when those are
        enabled.

        Returns early if the streaming weight loader set up in
        ``create_weights`` has already run this method for ``layer``.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Lazy import to avoid importing triton too early.
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # The checkpoint is not FP8 (e.g. fp16/bf16), so quantize it here.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for MARLIN if needed.
        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )


1480
1481
1482
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1483
1484
1485
    """

    def __init__(self, quant_config: Fp8Config):
1486
        super().__init__(quant_config)