fp8.py 60.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
10
from torch.utils._python_dispatch import TorchDispatchMode
11

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.attention.layer import Attention
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.batch_invariant import (
20
    vllm_is_batch_invariant,
21
)
bnellnm's avatar
bnellnm committed
22
from vllm.model_executor.layers.fused_moe import (
23
24
25
26
27
28
29
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
30
from vllm.model_executor.layers.fused_moe.config import (
31
    FusedMoEParallelConfig,
32
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
    fp8_w8a8_moe_quant_config,
35
    fp8_w8a16_moe_quant_config,
36
37
38
39
40
41
42
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
from vllm.model_executor.layers.quantization.base_config import (
45
46
47
    QuantizationConfig,
    QuantizeMethodBase,
)
48
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
49
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
50
51
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
52
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
53
54
55
56
57
58
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
64
    deepgemm_post_process_fp8_weight_block,
65
66
67
68
69
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
70
71
72
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
73
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
74
75
76
77
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
81
    GroupShape,
    is_layer_skipped,
)
82
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
83
84
85
86
87
88
89
90
91
92
93
94
95
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
96
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
97
from vllm.platforms import current_platform
98
99
100
101
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
102
from vllm.utils.flashinfer import has_flashinfer_moe
103
from vllm.utils.import_utils import has_deep_gemm
104

105
106
107
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

108
109
110
111
# Module-level logger.
logger = init_logger(__name__)

# Activation-scale schemes accepted by Fp8Config. With "static", FP8
# checkpoints provide an ``input_scale`` parameter per linear layer; with
# "dynamic", no input-scale parameter is loaded.
ACTIVATION_SCHEMES = ["static", "dynamic"]

112

113
114
115
116
117
class Fp8MoeBackend(Enum):
    """Kernel backends that can run an FP8 fused-MoE layer.

    The member values are fixed integers. Note that ``get_fp8_moe_backend``
    returns Python ``None`` (not ``NONE``) on XPU.
    """

    # Placeholder member meaning "no backend".
    NONE = 0
    # FlashInfer TensorRT-LLM ("latency") MoE kernels.
    FLASHINFER_TRTLLM = 1
    # FlashInfer CUTLASS ("throughput") MoE kernels.
    FLASHINFER_CUTLASS = 2
    # DeepGEMM kernels; only chosen for block-quantized weights.
    DEEPGEMM = 3
    # Marlin weight-only FP8 kernels for GPUs without native FP8.
    MARLIN = 4
    # Triton fused-MoE kernels; the default fallback.
    TRITON = 5
120
121


122
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend | None:
    """
    Select the primary FP8 MoE backend.

    Checks run in priority order; the first one that matches wins:

    1. XPU: ``None``.
    2. LoRA enabled: Triton.
    3. CUDA SM90 or SM100-family, ``VLLM_USE_FLASHINFER_MOE_FP8`` set and
       FlashInfer MoE available: FlashInfer TRTLLM when
       ``get_flashinfer_moe_backend()`` reports TENSORRT_LLM, otherwise
       FlashInfer CUTLASS.
    4. No SM89 capability or ``VLLM_TEST_FORCE_FP8_MARLIN`` set, except on
       ROCm: Marlin.
    5. ``VLLM_USE_DEEP_GEMM`` and the MoE DeepGEMM toggle enabled for
       block-quantized weights, with DeepGEMM installed and supported:
       DeepGEMM.
    6. Otherwise: Triton.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the MoE weights are block-quantized.
        moe_parallel_config: Parallel config; its ``tp_size`` decides the
            DeepGEMM default when ``VLLM_MOE_USE_DEEP_GEMM`` is unset.
        with_lora_support: Whether LoRA is enabled for the layer.

    Raises:
        ValueError: FlashInfer CUTLASS (throughput) would be used for
            block-quantized weights on an SM100-family device.
    """
    if current_platform.is_xpu():
        return None

    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # FlashInfer is preferred on CUDA SM90 / SM100-family devices when the
    # user opted in and the FlashInfer MoE kernels are available.
    flashinfer_eligible = (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    )
    if flashinfer_eligible:
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        # CUTLASS (throughput) backend: block quantization is rejected on
        # SM100-family devices.
        if block_quant and current_platform.is_device_capability_family(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only Marlin path for GPUs without native FP8 (or when forced
    # for testing); never used on ROCm.
    wants_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if wants_marlin and not current_platform.is_rocm():
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # MoE DeepGEMM toggle: an explicit VLLM_MOE_USE_DEEP_GEMM is honored
    # as-is; when it is unset, TP size >= 8 turns DeepGEMM off by default.
    deep_gemm_defaulted_off = (
        not envs.is_set("VLLM_MOE_USE_DEEP_GEMM")
        and moe_parallel_config.tp_size >= 8
    )
    if deep_gemm_defaulted_off:
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM and not deep_gemm_defaulted_off

    wants_deep_gemm = envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant
    if wants_deep_gemm and not has_deep_gemm():
        logger.warning_once(
            "DeepGEMM backend requested but not available.", scope="local"
        )
    elif wants_deep_gemm and is_deep_gemm_supported():
        logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
        return Fp8MoeBackend.DEEPGEMM

    # Nothing more specialized applies.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


197
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Describes either an FP8-serialized checkpoint or a higher-precision
    checkpoint whose weights are quantized to FP8 at load time. It hands out
    the matching quantize method for linear, fused-MoE and attention layers.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: ``True`` when the checkpoint weights
                are already stored in FP8.
            activation_scheme: Activation scale scheme; must be one of
                ``ACTIVATION_SCHEMES``.
            ignored_layers: Names of layers to keep unquantized. ``None`` is
                stored as an empty list.
            weight_block_size: Two-element block shape for block-wise weight
                quantization, or ``None`` to disable it.

        Raises:
            ValueError: On an unknown activation scheme, or when
                ``weight_block_size`` is given for a non-FP8 checkpoint, does
                not have exactly two entries, or is combined with a
                non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            # Block-wise weights need an FP8 checkpoint, a 2-D block shape
            # and dynamic activation scales; checked in that order.
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            block_dims = len(weight_block_size)
            if block_dims != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {block_dims} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes this config can run with."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability (75)."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No dedicated config filenames."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` from HF names to vLLM names in place."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build a config from a checkpoint quantization-config dict.

        ``quant_method`` and ``activation_scheme`` are read with
        ``get_from_keys``. The checkpoint counts as FP8-serialized when
        ``quant_method`` contains ``"fp8"``. When ``ignored_layers`` is
        missing or empty, ``modules_to_not_convert`` is used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        skip_list = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        skip_list = skip_list or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=activation_scheme,
            ignored_layers=skip_list,
            weight_block_size=block_size,
        )

    def _is_prefix_ignored(self, prefix: str) -> bool:
        """Whether ``is_layer_skipped`` excludes the layer at ``prefix``.

        Consults ``ignored_layers`` and the packed-modules mapping.
        """
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU quantize method for ``layer``, or ``None``.

        Linear and fused-MoE layers get the IPEX FP8 methods unless they are
        skipped. Attention layers get the FP8 KV-cache method.
        """
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh config built from this one's fields.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_prefix_ignored(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            if self._is_prefix_ignored(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or ``None``.

        On XPU, this delegates to ``get_xpu_quant_method``. Otherwise:
        - Linear layers get ``Fp8LinearMethod``.
        - Fused-MoE layers get ``Fp8MoEMethod`` for FP8 checkpoints and
          ``Fp8OnlineMoEMethod`` otherwise.
        - Attention layers get ``Fp8KVCacheMethod``.

        Skipped linear/MoE layers get the unquantized methods.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, LinearBase):
            if self._is_prefix_ignored(prefix):
                return UnquantizedLinearMethod()
            linear_method = Fp8LinearMethod(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if self._is_prefix_ignored(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_method_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            moe_method = moe_method_cls(self, layer)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a checkpoint attention-scale parameter name to vLLM's name.

        Rules, tried in this order:
        - ``*.k_proj.output_scale`` becomes ``*.attn.k_scale``.
        - ``*.v_proj.output_scale`` becomes ``*.attn.v_scale``.
        - ``*.q_proj.output_scale`` becomes ``*.attn.q_scale``.
        - ``*self_attn.prob_output_scale`` becomes
          ``*self_attn.attn.prob_scale``.

        :param name: param name
        :return: matching param name for the scale in vLLM, or ``None``
            when no rule applies
        """
        if name.endswith(".output_scale"):
            for marker, old_suffix, new_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None

360

361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
class CopyNumelCounter(TorchDispatchMode):
    """Dispatch mode that totals the elements written by ``copy_``.

    While the mode is active, every ``aten.copy_`` call adds the element
    count of its destination (first positional argument) to
    ``copied_numel``. The destination is what gets counted, so the total
    reflects how much of a parameter was written even when a loader first
    reshapes or slices the target (for example with ``narrow``).
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written by aten.copy_.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Execute the intercepted op first; a copy that raises is not counted.
        result = func(*args, **(kwargs if kwargs is not None else {}))
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


381
382
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale. Weights are scaled per tensor, or block
    by block when the config sets ``weight_block_size``.

    Also supports FP16/BF16 checkpoints, which are quantized to FP8 online
    with dynamic activation scaling. The weight scaling factor is computed
    after the weights are loaded.

    Limitations of the non-block path:
    1. Outside the Marlin path, fused logical shards are requantized to a
       single per-tensor scale due to torch._scaled_mm support.
    2. Only float8_e4m3fn is supported due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization. The Marlin input dtype
        # may be overridden after construction (Fp8Config.get_quant_method
        # sets it per layer prefix).
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode never uses Marlin.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight, scale and input-scale parameters on ``layer``.

        For FP8 checkpoints, this registers an FP8 ``weight`` plus a
        ``weight_scale``, or a ``weight_scale_inv`` for block quantization.
        It also registers an ``input_scale``, which is ``None`` unless
        activations are static.

        For other checkpoints, ``weight`` is allocated in ``params_dtype``.
        Its loader is wrapped to count copied elements: once the whole
        weight has been written, it runs ``process_weights_after_loading``
        immediately.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded parameters into the layout the kernels expect.

        This is a no-op if the online-quantization loader already ran it.

        - Block-quantized layers: the weight and ``weight_scale_inv`` are
          post-processed.
        - Other layers: the weight is quantized (non-FP8 checkpoints) or
          requantized to a single scale (FP8 checkpoints, non-Marlin), then
          stored transposed.
        - Marlin: the layer is repacked and ``input_scale`` is removed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _apply_dequantized_gemm(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Dequantize the FP8 weight to ``x.dtype`` and run a plain GEMM.

        This serves the batch-invariant path for non-block layers. After
        ``process_weights_after_loading``, the weight is stored transposed
        as ``[K, N]`` (input, output). A scale with one entry per output
        channel therefore multiplies the columns.
        """
        # Dequantize in the activation dtype: F.linear requires matching
        # dtypes, so a fixed bfloat16 weight fails for float16 activations.
        compute_dtype = x.dtype
        weight = layer.weight.to(compute_dtype)
        weight_scale = layer.weight_scale.to(compute_dtype)

        if weight_scale.numel() == 1:
            # Per-tensor: simple scalar multiplication.
            weight = weight * weight_scale
        elif weight_scale.numel() == weight.shape[1]:
            # One scale per output channel, i.e. per column of [K, N].
            weight = weight * weight_scale.reshape(1, -1)
        elif weight_scale.dim() == 1 and weight_scale.shape[0] == weight.shape[0]:
            # Best-effort: a 1-D scale matching the input dim scales rows.
            weight = weight * weight_scale.unsqueeze(1)
        else:
            # Fallback: rely on plain broadcasting.
            weight = weight * weight_scale

        return torch.nn.functional.linear(x, weight.t(), bias)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the FP8 linear layer for input ``x``.

        Paths, in order:
        1. Batch-invariant mode: the block-wise W8A8 op for block-quantized
           layers, otherwise a dequantized GEMM.
        2. Marlin.
        3. Block-wise W8A8.
        4. Per-tensor ``Fp8LinearOp``.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use a dequantized GEMM when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # per-tensor/channel: dequant to the activation dtype and run GEMM
            return self._apply_dequantized_gemm(layer, x, bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
705
706


707
708
709
710
711
712
713
714
715
716
717
718
719
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale. Unquantized FP16/BF16 checkpoints
    (online quantization with dynamic activation scaling) are handled by the
    ``Fp8OnlineMoEMethod`` subclass, which overrides ``create_weights``.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to. Its ``moe_config``
            and ``moe_parallel_config`` are used to pick the FP8 MoE backend,
            and its routing/activation attributes are validated against it.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                # The per-tensor FlashInfer path requires a custom routing
                # function with renormalization disabled; this mirrors the
                # per-tensor TRT-LLM assertion in `apply()`.
                if layer.renormalize or layer.custom_routing_function is None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend requires a custom "
                        "routing function without renormalization, but got "
                        f"renormalize={layer.renormalize} and "
                        f"custom_routing_function={layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales.

        Weights are always allocated as ``torch.float8_e4m3fn`` because this
        method only handles FP8-serialized checkpoints. Per-tensor
        checkpoints get two ``w13`` scales per expert (one for w1 and one for
        w3), which are merged in ``process_weights_after_loading``.
        Block-quantized checkpoints get ``*_weight_scale_inv`` tensors sized
        by ``weight_block_size``.

        Raises:
            ValueError: If ``intermediate_size_per_partition`` is not
                divisible by the block sizes required for block quantization.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the layout the backend expects.

        This normalizes FP8 formats on fnuz platforms, merges the per-shard
        w13 scales (per-tensor case), and applies backend-specific
        reshuffles (FlashInfer, ROCm AITER, DeepGEMM, Marlin). It then builds
        the modular kernel (``self.kernel``) for the FlashInfer CUTLASS,
        DeepGEMM, Triton and Marlin backends.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                # Use plain tensors (not Parameters) for every entry, as the
                # other branches do, so the replacements below are uniform.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            replace_parameter(layer, "w13_weight", w13_weight)
            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
            replace_parameter(layer, "w2_weight", w2_weight)
            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                replace_parameter(layer, "w13_weight", w13_weight)
                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
                if w13_input_scale is not None:
                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
                replace_parameter(layer, "w2_weight", w2_weight)
                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
                if w2_input_scale is not None:
                    replace_parameter(layer, "w2_input_scale", w2_input_scale)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            replace_parameter(layer, "w13_weight_scale", max_w13_scales)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # A local `w2_weight` is only bound on the fnuz branch
                    # above; on other platforms it would be unbound here.
                    # Read the (possibly replaced) parameter from the layer.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        # NOTE(rob): this is a WIP refactor. We are first migrating
        # all of the kernels in the TP case to use mk. Once this is
        # done, then we will initialize the TP case and DP/EP case
        # via the same code path (i.e. via maybe_init_modular_kernel).
        # NOTE(rob): in progress migrating all into this format.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
                FlashInferAllGatherMoEPrepareAndFinalize,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config

            self.kernel = mk.FusedMoEModularKernel(
                FlashInferAllGatherMoEPrepareAndFinalize(
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
                FlashInferExperts(
                    out_dtype=torch.get_default_dtype(),
                    quant_config=self.moe_quant_config,
                    ep_rank=self.moe.ep_rank,
                    ep_size=self.moe.ep_size,
                    tp_rank=self.moe.tp_rank,
                    tp_size=self.moe.tp_size,
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
            )
            self.use_inplace = False

        elif self.fp8_backend in [
            Fp8MoeBackend.DEEPGEMM,
            Fp8MoeBackend.TRITON,
            Fp8MoeBackend.MARLIN,
        ]:
            from vllm.model_executor.layers.fused_moe import (
                TritonOrDeepGemmExperts,
            )
            from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
                MarlinExperts,
            )
            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                MoEPrepareAndFinalizeNoEP,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config
            use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
            allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
            moe_kernel = (
                MarlinExperts(quant_config=self.moe_quant_config)
                if use_marlin
                else TritonOrDeepGemmExperts(
                    quant_config=self.moe_quant_config,
                    allow_deep_gemm=allow_deep_gemm,
                )
            )

            self.kernel = mk.FusedMoEModularKernel(
                MoEPrepareAndFinalizeNoEP(), moe_kernel
            )
            self.use_inplace = True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for all2all, or None.

        ROCm AITER, Marlin and FlashInfer TRT-LLM have no all2all support
        here and return None. FlashInfer CUTLASS builds its own
        prepare/finalize; everything else defers to the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.fp8_backend == Fp8MoeBackend.MARLIN
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching ``prepare_finalize``.

        Batched activation formats get batched DeepGEMM/Triton experts,
        LoRA uses Triton, FlashInfer CUTLASS uses its CUTLASS experts, and
        everything else uses ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert (
            self.fp8_backend != Fp8MoeBackend.MARLIN
        ) and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's scales.

        Marlin uses W8A16 (activations are not quantized); other backends
        use W8A8, with block-inverse scales when block-quantized.
        """
        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            return fp8_w8a16_moe_quant_config(
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                block_shape=self.weight_block_size,
            )

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass for ``x`` given ``router_logits``.

        FlashInfer TRT-LLM performs routing inside its kernel and returns
        immediately. Every other backend routes via ``layer.select_experts``
        and then runs ROCm AITER or the modular kernel built in
        ``process_weights_after_loading``.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                # Return directly: no modular kernel (`self.kernel`) is built
                # for the TRT-LLM backend, so falling through to the generic
                # path below would fail.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            # TODO(rob): convert this to MK.
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )
        else:
            result = self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=self.use_inplace,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        return result
1349
1350


1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Supports loading unquantized (FP16/BF16) model checkpoints with dynamic
    activation scaling. The expert weights are quantized to FP8, and their
    per-expert weight scales are computed, once all weight chunks have been
    loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Online quantization requires a non-FP8 checkpoint, dynamic
        # activation scales, per-expert (non-block) weight scales and no
        # FlashInfer MoE backend.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the unquantized expert weights and per-expert FP8 scales.

        The weights are allocated in ``params_dtype`` and filled through a
        patched weight loader. Once every element of ``w13_weight`` and
        ``w2_weight`` has been copied in, the loader immediately calls
        ``process_weights_after_loading`` so the layer is quantized in a
        streaming fashion.

        Args:
            layer: The MoE layer to register parameters on.
            num_experts: Number of experts to allocate weights for.
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Intermediate size of each
                expert's MLP on this partition.
            params_dtype: Dtype of the (unquantized) checkpoint weights.
            **extra_weight_attrs: Attributes attached to the weight
                parameters; must contain ``weight_loader``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization: patch the weight loader to call
        # `process_weights_after_loading` in a streaming fashion as soon as
        # the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Add a counter to track how many elements have been loaded.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Load the current weight chunk, counting the copied elements.
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # If all elements of both weights have been loaded, call
            # process_weights_after_loading.
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping.
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything.
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a new mapping (instead of mutating the incoming one) so that
        # swapping in the patched loader cannot affect any other object that
        # holds a reference to the original attrs.
        extra_weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS (kept in params_dtype until quantized)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for w13 and one for w2. They are placeholders
        # (ones) until `process_weights_after_loading` computes them.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Activation scales are dynamic, so there are no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded expert weights to FP8 with per-expert scales.

        Returns immediately if the patched weight loader has already run
        this method for ``layer``. Afterwards the weights are shuffled for
        AITER or repacked for Marlin when those backends are in use.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Decide whether the ROCm AITER fused-MoE path is used; this
        # overrides the default of False set in `create_weights`.
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # The checkpoint is not FP8: quantize each expert's weights into
        # newly allocated FP8 tensors and record one scale per expert.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for MARLIN if needed.
        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )


1498
1499
1500
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1501
1502
1503
    """

    def __init__(self, quant_config: Fp8Config):
1504
        super().__init__(quant_config)