fp8.py 60 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
10
from torch.utils._python_dispatch import TorchDispatchMode
11

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.attention.layer import Attention
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.batch_invariant import (
20
    vllm_is_batch_invariant,
21
)
bnellnm's avatar
bnellnm committed
22
from vllm.model_executor.layers.fused_moe import (
23
24
25
26
27
28
29
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
30
from vllm.model_executor.layers.fused_moe.config import (
31
    FusedMoEParallelConfig,
32
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
    fp8_w8a8_moe_quant_config,
35
    fp8_w8a16_moe_quant_config,
36
37
38
39
40
41
42
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
from vllm.model_executor.layers.quantization.base_config import (
45
46
47
    QuantizationConfig,
    QuantizeMethodBase,
)
48
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
49
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
50
51
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
52
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
53
54
55
56
57
58
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
64
    deepgemm_post_process_fp8_weight_block,
65
66
67
68
69
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
70
71
72
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
73
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
74
75
76
77
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
81
    GroupShape,
    is_layer_skipped,
)
82
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
83
84
85
86
87
88
89
90
91
92
93
94
95
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
96
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
97
from vllm.platforms import current_platform
98
99
100
101
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
102
from vllm.utils.flashinfer import has_flashinfer_moe
103
from vllm.utils.import_utils import has_deep_gemm
104

105
106
107
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

108
109
110
111
# Accepted values for Fp8Config's ``activation_scheme`` argument;
# Fp8Config.__init__ raises ValueError for any other value.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger used for backend-selection and loading messages.
logger = init_logger(__name__)

112

113
114
115
116
117
class Fp8MoeBackend(Enum):
    """Kernel backends that an FP8 fused-MoE layer can be routed to.

    Selected by ``get_fp8_moe_backend``. The integer values are plain
    identifiers; ``NONE`` is the zero value.
    """

    NONE = 0
    # FlashInfer TensorRT-LLM kernels (FlashinferMoeBackend.TENSORRT_LLM).
    FLASHINFER_TRTLLM = 1
    # FlashInfer CUTLASS kernels (FlashinferMoeBackend.CUTLASS).
    FLASHINFER_CUTLASS = 2
    # DeepGEMM kernels; only chosen for block-quantized weights.
    DEEPGEMM = 3
    # Weight-only Marlin kernels for GPUs without native FP8 support.
    MARLIN = 4
    # Triton kernels; the default fallback and the LoRA path.
    TRITON = 5
    # ROCm AITER kernels.
    AITER = 6
121
122


123
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend | None:
    """
    Select the primary FP8 MoE backend.

    Selection order (first match wins):
      1. XPU platform -> ``None``.
      2. LoRA enabled -> Triton.
      3. CUDA SM90 / SM100-family with VLLM_USE_FLASHINFER_MOE_FP8 and
         FlashInfer MoE available -> FlashInfer TRTLLM or CUTLASS.
      4. No native FP8 (capability < 8.9) or VLLM_TEST_FORCE_FP8_MARLIN,
         and not ROCm -> Marlin.
      5. Block-quantized weights with DeepGEMM enabled and supported
         -> DeepGEMM.
      6. ROCm with AITER MoE enabled -> AITER.
      7. Otherwise -> Triton.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the MoE weights are block-quantized.
        moe_parallel_config: Parallel config; only ``tp_size`` is read.
        with_lora_support: Whether LoRA is enabled for the layer.

    Returns:
        The selected backend, or ``None`` on XPU.

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) backend would be
            used for block-quantized weights on an SM100-family GPU.
    """
    if current_platform.is_xpu():
        return None

    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM

        # Any other FlashInfer backend maps to the CUTLASS kernels.
        if block_quant and current_platform.is_device_capability_family(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # Determine if we should use DeepGEMM with block-quantized weights:
    # - If explicitly set by user, respect their choice
    # - If not explicitly set (default), disable when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM
        # Otherwise DeepGEMM is installed but unsupported here; fall through.

    # AITER kernels are ROCm-only. Guard on the platform (as the Marlin
    # branch above does) so stray AITER env settings on another platform
    # cannot select them.
    if (
        current_platform.is_rocm()
        and envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_MOE
    ):
        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
        return Fp8MoeBackend.AITER

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


202
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds checkpoint-level FP8 settings and hands out the per-layer
    quantization method (linear, fused MoE, or KV cache) through
    :meth:`get_quant_method`.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: True if the checkpoint already stores
                FP8 weights; False if weights are loaded in their original
                dtype and quantized after loading.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer names to leave unquantized (matched via
                ``is_layer_skipped``). ``None`` is treated as empty.
            weight_block_size: Optional 2-D block shape for block-wise weight
                quantization.

        Raises:
            ValueError: On an unknown activation scheme, or when block-wise
                quantization is requested with a non-serialized checkpoint,
                a block shape that is not 2-D, or a non-dynamic activation
                scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``ignored_layers`` from HF names to vLLM names in place."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint quantization-config dict.

        ``quant_method`` and ``activation_scheme`` are required keys. The
        checkpoint counts as FP8-serialized when ``quant_method`` contains
        "fp8". ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Delegate to ``is_layer_skipped`` with this config's ignored layers
        and packed-module mapping."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU-specific quantization method for ``layer``."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # NOTE(review): this rebuilds a config from self's constructor
        # arguments only; any state not passed through __init__ is not
        # carried over. Confirm whether passing `self` would suffice.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``, or None.

        Skipped layers get the unquantized method. FP8-serialized checkpoints
        use Fp8MoEMethod for MoE layers; otherwise Fp8OnlineMoEMethod.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

365

366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that tallies how many elements ``aten.copy_`` writes.

    Every op dispatched while the mode is active runs unchanged. For each
    ``copy_`` call, the element count of its destination (first positional
    argument) is added to ``copied_numel``. Weight loaders may reshape or
    ``narrow`` the target before copying; counting at the ``copy_`` level
    still reports how much was actually written.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written through copy_.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs if kwargs is not None else {}))
        # Tally only after the op has run, so a copy that raises is not counted.
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


386
387
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        For FP8-serialized checkpoints the weight is created as an FP8
        parameter together with ``weight_scale`` (per-tensor) or
        ``weight_scale_inv`` (block), plus ``input_scale`` when the activation
        scheme is static. For non-serialized checkpoints the weight is kept in
        ``params_dtype``, and quantization runs as soon as every element has
        been loaded (tracked with ``CopyNumelCounter``).
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the layout the kernels expect.

        No-op if already run by the patched weight loader. Block-quantized
        weights go through ``process_fp8_weight_block_strategy``. Otherwise,
        non-serialized weights are quantized to FP8 here, and serialized
        per-shard scales are merged (unless Marlin is used). The non-block
        weight is stored transposed. Finally the layer is prepared for
        Marlin, or block-quantized weights are post-processed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def _apply_dequantized(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Dequantize the FP8 weight to BF16 and run a plain GEMM.

        Used for the non-block path in batch-invariant mode. The weight was
        stored transposed (``[K, N]``) by ``process_weights_after_loading``.
        """
        weight_fp8 = layer.weight.to(torch.bfloat16)
        weight_scale = layer.weight_scale.to(torch.bfloat16)

        if weight_scale.numel() == 1:
            # Per-tensor: simple scalar multiplication
            weight_bf16 = weight_fp8 * weight_scale
        elif weight_scale.dim() == 1 and weight_scale.shape[0] == weight_fp8.shape[0]:
            # Per-row scaling.
            # NOTE(review): the weight is [K, N] here, so shape[0] is the input
            # dim, while a per-output-channel scale would have length N.
            # Confirm the intended axis before relying on this branch.
            weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
        else:
            # Fallback: rely on broadcasting.
            weight_bf16 = weight_fp8 * weight_scale

        # The dequantized weight is BF16, but activations may be FP16 (see
        # Fp8Config.get_supported_act_dtypes). F.linear rejects mixed dtypes,
        # so cast the weight to the activation dtype (a no-op for BF16).
        return torch.nn.functional.linear(x, weight_bf16.t().to(x.dtype), bias)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` with an optional ``bias``.

        Dispatch order: batch-invariant mode, then Marlin, then the
        block-quantized kernel, then the per-tensor/per-token FP8 kernel.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # per-tensor/channel: dequant and run GEMM
            return self._apply_dequantized(layer, x, bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
710
711


712
713
714
715
716
717
718
719
720
721
722
723
724
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with either per-tensor or block-wise
    weight scales, and dynamic/static activation scales (block-wise weight
    scales require the dynamic activation scheme).

    Online quantization of FP16/BF16 checkpoints is handled by the
    `Fp8OnlineMoEMethod` subclass; `create_weights` here asserts that the
    checkpoint is FP8-serialized.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to. Its `moe_config`
            and `moe_parallel_config` are used to choose the kernel backend.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        # Map the generic FP8 backend choice onto the FlashInfer flavour (if any)
        # and reject layer configurations the CUTLASS backend cannot run.
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                if layer.renormalize or layer.custom_routing_function is not None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend does not support custom "
                        "routing function or renormalization, but got "
                        f"{layer.renormalize} and {layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales on `layer`.

        Registers `w13_weight` with shape (num_experts,
        2 * intermediate_size_per_partition, hidden_size) and `w2_weight` with
        shape (num_experts, hidden_size, intermediate_size_per_partition), both
        as float8_e4m3fn. Per-tensor scales are registered as
        `w13_weight_scale` / `w2_weight_scale`. Block-wise scales are registered
        as `w13_weight_scale_inv` / `w2_weight_scale_inv`. Input scales are
        registered only for the static activation scheme; otherwise they are
        set to None.

        Raises:
            ValueError: if block quantization is used and the per-partition
                intermediate size is not divisible by the block size.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding up partial tiles.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded FP8 weights/scales into the layout the backend needs.

        Depending on platform and backend, this:
        - normalizes e4m3fn to e4m3fnuz on fnuz platforms;
        - swaps the w1/w3 halves for FlashInfer;
        - shuffles weights for AITER and post-processes block scales for
          DeepGEMM;
        - requantizes w13 to a single per-expert scale in the per-tensor case;
        - builds the modular kernel `self.kernel` for the supported backends.

        Returns early if `layer._already_called_process_weights_after_loading`
        is truthy. That flag is not set in this method; presumably a caller
        or subclass sets it (verify).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            replace_parameter(layer, "w13_weight", w13_weight)
            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
            replace_parameter(layer, "w2_weight", w2_weight)
            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
            if self.fp8_backend == Fp8MoeBackend.AITER:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )

        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                replace_parameter(layer, "w13_weight", w13_weight)
                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
                if w13_input_scale is not None:
                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
                replace_parameter(layer, "w2_weight", w2_weight)
                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
                if w2_input_scale is not None:
                    replace_parameter(layer, "w2_input_scale", w2_input_scale)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.fp8_backend == Fp8MoeBackend.AITER:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            replace_parameter(layer, "w13_weight_scale", max_w13_scales)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Read w2 from the layer: the local `w2_weight` is only
                    # bound on the fnuz path above, and FlashInfer does not
                    # take that path.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight.data)
                layer.w13_weight.data = w13_weight.data

        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        # NOTE(rob): this is a WIP refactor. We are first migrating
        # all of the kernels in the TP case to use mk. Once this is
        # done, then we will initialize the TP case and DP/EP case
        # via the same code path (i.e. via maybe_init_modular_kernel).
        # NOTE(rob): in progress migrating all into this format.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
                FlashInferAllGatherMoEPrepareAndFinalize,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config

            self.kernel = mk.FusedMoEModularKernel(
                # TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
                # with the changes to defer input quantization
                FlashInferAllGatherMoEPrepareAndFinalize(
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
                FlashInferExperts(
                    out_dtype=torch.get_default_dtype(),
                    quant_config=self.moe_quant_config,
                    ep_rank=self.moe.ep_rank,
                    ep_size=self.moe.ep_size,
                    tp_rank=self.moe.tp_rank,
                    tp_size=self.moe.tp_size,
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
            )
            self.use_inplace = False

        elif self.fp8_backend in [
            Fp8MoeBackend.DEEPGEMM,
            Fp8MoeBackend.TRITON,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.AITER,
        ]:
            from vllm.model_executor.layers.fused_moe import (
                TritonOrDeepGemmExperts,
            )
            from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
                MarlinExperts,
            )
            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                MoEPrepareAndFinalizeNoEP,
            )
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                AiterExperts,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config

            if self.fp8_backend == Fp8MoeBackend.AITER:
                self.kernel = mk.FusedMoEModularKernel(
                    # TODO: make defer_input_quant an attr of the AiterExperts
                    MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
                    AiterExperts(quant_config=self.moe_quant_config),
                )
            elif self.fp8_backend == Fp8MoeBackend.MARLIN:
                self.kernel = mk.FusedMoEModularKernel(
                    MoEPrepareAndFinalizeNoEP(),
                    MarlinExperts(quant_config=self.moe_quant_config),
                )
            else:
                self.kernel = mk.FusedMoEModularKernel(
                    MoEPrepareAndFinalizeNoEP(),
                    TritonOrDeepGemmExperts(
                        quant_config=self.moe_quant_config,
                        allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
                    ),
                )
            self.use_inplace = True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Build the all2all prepare/finalize object for this backend.

        Returns None for AITER, Marlin and FlashInfer TRT-LLM. Returns a
        FlashInfer CUTLASS prepare/finalize for the CUTLASS backend. Otherwise
        defers to the base class.
        """
        if (
            self.fp8_backend == Fp8MoeBackend.AITER
            or self.fp8_backend == Fp8MoeBackend.MARLIN
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation to pair with `prepare_finalize`.

        Raises:
            NotImplementedError: for the Marlin and AITER backends, which are
                not supported with all2all.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Describe this layer's weight/activation scales for the MoE kernels.

        Marlin gets the w8a16 config. Activations are not quantized for
        Marlin, and its input scales are deleted in
        `process_weights_after_loading`. Every other backend gets the w8a8
        config, using the `*_scale_inv` tensors when block-quantized.
        """
        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            return fp8_w8a16_moe_quant_config(
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                block_shape=self.weight_block_size,
            )

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass for `x`.

        The FlashInfer TRT-LLM backend is called directly with `router_logits`
        and returns immediately. Every other backend first selects experts with
        `layer.select_experts` and then runs the modular kernel `self.kernel`
        built in `process_weights_after_loading`.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                result = apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )
                # Return here: no modular kernel is built for TRT-LLM, so
                # falling through to `self.kernel` below would fail.
                return result

        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1340
1341


1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads unquantized FP16/BF16 model checkpoints and quantizes the expert
    weights to FP8 after they are loaded, using dynamic activation scaling.
    Each expert gets one weight scale for w13 and one for w2; the scales are
    computed either as soon as the last weight chunk of the layer has been
    loaded (via a patched weight loader) or in the regular
    `process_weights_after_loading` call, whichever happens first.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Online quantization requires a non-FP8 checkpoint, dynamic
        # activation scales, per-tensor (non-block) weight scales and no
        # FlashInfer MoE backend.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights and per-expert FP8 scales.

        The weights are allocated in `params_dtype` and are replaced by FP8
        tensors in `process_weights_after_loading`. The `weight_loader` found
        in `extra_weight_attrs` is wrapped so that quantization runs as soon
        as every element of `w13_weight` and `w2_weight` has been written,
        instead of only in the usual post-load call.

        Args:
            layer: The MoE layer to register parameters on.
            num_experts: Number of experts whose weights are allocated.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Dtype of the (unquantized) checkpoint weights.
            **extra_weight_attrs: Attributes attached to each weight; must
                contain `weight_loader`.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization: patch the weight loader so that
        # `process_weights_after_loading` runs in a streaming fashion as soon
        # as the last weight chunk of this layer is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # Work on a copy so that installing the patched loader does not
        # modify a mapping that other objects might still depend on.
        extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Lazily create a per-layer counter of how many weight elements
            # have been copied into this layer so far.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Load the current weight chunk, counting the elements copied.
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # Once every element of w13 and w2 has been loaded, quantize now.
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                # This marks the layer as processed, so the usual
                # `process_weights_after_loading` call becomes a no-op.
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping.
                del layer._loaded_numel

            return res

        extra_weight_attrs["weight_loader"] = patched_weight_loader

        # WEIGHTS (kept in the checkpoint dtype until quantized)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One FP32 scale per expert for w13 (covering both w1 and w3) and one
        # per expert for w2. They start at 1.0 and are overwritten by
        # `process_weights_after_loading` when the weights are quantized.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Activation scales are dynamic, so there are no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision expert weights to FP8.

        Each expert's w13 and w2 weight is quantized separately with
        `ops.scaled_fp8_quant`, storing the returned scale in
        `w13_weight_scale` / `w2_weight_scale`. The FP8 weights are then
        reshuffled for the AITER or Marlin backend when one of those is
        selected. The layer is marked as processed, so any later call (e.g.
        the regular post-load hook after the streaming loader already ran)
        returns immediately.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # The checkpoint weights are FP16/BF16: quantize them into freshly
        # allocated FP8 tensors, then swap those in for the originals.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.fp8_backend == Fp8MoeBackend.AITER:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for Marlin if needed.
        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )

        # The weights are now FP8; quantizing them again would be wrong, so
        # make every further call a no-op.
        layer._already_called_process_weights_after_loading = True


1484
1485
1486
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1487
1488
1489
    """

    def __init__(self, quant_config: Fp8Config):
1490
        super().__init__(quant_config)