fp8.py 57.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8

import torch
from torch.nn import Module
9
from torch.utils._python_dispatch import TorchDispatchMode
10

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm._aiter_ops import rocm_aiter_ops
15
from vllm.attention.layer import Attention
16
from vllm.distributed import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.batch_invariant import (
19
    vllm_is_batch_invariant,
20
)
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe import (
22
23
24
25
26
27
28
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
    FusedMoEParallelConfig,
31
    FusedMoEQuantConfig,
32
    RoutingMethodType,
33
    fp8_w8a8_moe_quant_config,
34
    fp8_w8a16_moe_quant_config,
35
36
37
38
39
40
41
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
from vllm.model_executor.layers.quantization.base_config import (
44
45
46
    QuantizationConfig,
    QuantizeMethodBase,
)
47
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
48
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
49
50
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
51
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
52
53
54
55
56
57
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
58
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
59
60
61
62
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
63
    deepgemm_post_process_fp8_weight_block,
64
65
66
67
68
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
69
70
71
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
72
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
73
74
75
76
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
77
from vllm.model_executor.layers.quantization.utils.quant_utils import (
78
79
80
    GroupShape,
    is_layer_skipped,
)
81
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
82
83
84
85
86
87
88
89
90
91
92
93
94
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
95
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
96
from vllm.platforms import current_platform
97
98
99
100
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
101
from vllm.utils.flashinfer import has_flashinfer_moe
102
from vllm.utils.import_utils import has_deep_gemm
103

104
105
106
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

107
108
109
110
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

111

112
113
114
115
116
class Fp8MoeBackend(Enum):
    """Identifiers for the kernel backends that can execute an FP8 MoE layer.

    The numeric values are plain tags; nothing in this enum encodes a
    priority between backends.
    """

    # No backend.
    NONE = 0

    # FlashInfer-provided backends.
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2

    # DeepGEMM backend.
    DEEPGEMM = 3

    # Marlin backend.
    MARLIN = 4

    # Triton backend.
    TRITON = 5

    # ROCm AITER backend.
    AITER = 6
120
121


122
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend | None:
    """
    Choose the primary backend for FP8 MoE layers.

    Candidates are checked in a fixed order and the first match wins:
    XPU (no backend), LoRA (Triton), FlashInfer, Marlin, DeepGEMM,
    ROCm AITER, and finally Triton as the default.
    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the weights use block-wise quantization.
        moe_parallel_config: Parallel configuration of the MoE layer; only
            its ``tp_size`` is read here.
        with_lora_support: Whether the layer must support LoRA.

    Returns:
        The selected backend, or ``None`` on XPU platforms.

    Raises:
        ValueError: If the FlashInfer non-TRTLLM backend is selected with
            block quantization on the capability-100 device family.
    """
    if current_platform.is_xpu():
        return None

    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    flashinfer_capable_gpu = current_platform.is_cuda() and (
        current_platform.is_device_capability_family(100)
        or current_platform.is_device_capability(90)
    )
    if (
        flashinfer_capable_gpu
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        if get_flashinfer_moe_backend() == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM

        if block_quant and current_platform.is_device_capability_family(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only path for older GPUs without native FP8 (or when forced
    # through the test override); never taken on ROCm.
    lacks_native_fp8 = not current_platform.has_device_capability(89)
    marlin_requested = lacks_native_fp8 or envs.VLLM_TEST_FORCE_FP8_MARLIN
    on_rocm = current_platform.is_rocm()
    if marlin_requested and not on_rocm:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # Decide whether DeepGEMM may be used with block-quantized weights:
    # - an explicit user setting is always respected
    # - with the default setting, it is turned off when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    user_set_deep_gemm = envs.is_set("VLLM_MOE_USE_DEEP_GEMM")
    if not user_set_deep_gemm and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    deep_gemm_requested = (
        envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant
    )
    if deep_gemm_requested:
        if has_deep_gemm():
            if is_deep_gemm_supported():
                logger.info_once(
                    "Using DeepGEMM backend for FP8 MoE", scope="local"
                )
                return Fp8MoeBackend.DEEPGEMM
        else:
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )

    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
        return Fp8MoeBackend.AITER

    # Nothing more specific applied: default to Triton.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


201
class Fp8Config(QuantizationConfig):
    """Quantization configuration for FP8.

    Holds whether the checkpoint is already stored in FP8, the activation
    scheme, the list of layers to leave unquantized, and the optional
    block size for block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """Validate and store the FP8 settings.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint weights are
                already serialized as FP8.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer names to skip; ``None`` becomes ``[]``.
            weight_block_size: Two-element block shape enabling block-wise
                quantization, or ``None`` to disable it.

        Raises:
            ValueError: On an unknown activation scheme, or when block-wise
                quantization is requested without an FP8-serialized
                checkpoint, with a block size that is not 2-dimensional, or
                with a non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        scheme_is_known = activation_scheme in ACTIVATION_SCHEMES
        if not scheme_is_known:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        # The block-wise checks only apply when a block size was given.
        block_wise = weight_block_size is not None
        if block_wise and not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        if block_wise and len(weight_block_size) != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {len(weight_block_size)} dimensions"
            )
        if block_wise and activation_scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{activation_scheme} activation scheme."
            )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this config supports."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config filenames to look for (none for FP8)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` through the given weights mapper."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dictionary.

        The checkpoint counts as FP8-serialized when ``"fp8"`` occurs in the
        ``quant_method`` entry. When ``ignored_layers`` is missing or empty,
        ``modules_to_not_convert`` is used instead.
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        skipped = skipped or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU quantization method for ``layer``.

        Linear and fused-MoE layers get their XPU FP8 method, or the
        unquantized method when the layer is skipped. Attention layers get
        the FP8 KV-cache method. Any other layer yields ``None``.
        """
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh config carrying the same settings.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        def layer_is_skipped() -> bool:
            return is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )

        if isinstance(layer, LinearBase):
            if layer_is_skipped():
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)

        if isinstance(layer, FusedMoE):
            if layer_is_skipped():
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(xpu_config, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method to use for ``layer``.

        XPU platforms are delegated to ``get_xpu_quant_method``. Otherwise
        linear and fused-MoE layers get an FP8 method (or the unquantized
        method when skipped), attention layers get the FP8 KV-cache method,
        and any other layer yields ``None``.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        def layer_is_skipped() -> bool:
            return is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )

        if isinstance(layer, LinearBase):
            if layer_is_skipped():
                return UnquantizedLinearMethod()
            linear_method = Fp8LinearMethod(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if layer_is_skipped():
                return UnquantizedFusedMoEMethod(layer.moe_config)
            # Serialized FP8 checkpoints and online quantization use
            # different MoE method classes.
            moe_method_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            moe_method = moe_method_cls(self, layer)
            moe_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_method

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)

        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a compressed-tensors k/v cache scale param name to the
        equivalent param name expected by vLLM.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
            when the name is not a recognized cache scale
        """
        # (marker that must occur in the name, text to replace, replacement),
        # checked in this order.
        projection_renames = (
            (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
            (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
            (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
        )
        if name.endswith(".output_scale"):
            for marker, old_suffix, new_suffix in projection_renames:
                if marker in name:
                    return name.replace(old_suffix, new_suffix)

        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")

        # If no matches, return None
        return None

364

365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that accumulates how many elements were written through
    `copy_` while the mode is active.

    For every `aten.copy_` call the element count of the destination tensor
    (the first positional argument) is added to `copied_numel`. This makes
    it possible to track weight loading even when the target is reshaped
    or sliced (e.g. with `narrow`) before the copy happens.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements seen by `copy_`.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Run the intercepted op first; the count is only updated when the
        # op itself completed.
        result = func(*args, **({} if kwargs is None else kwargs))
        if func == torch.ops.aten.copy_.default:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


385
386
class Fp8LinearMethod(LinearMethodBase):
    """FP8 quantization method for linear layers.

    Handles two kinds of checkpoints:

    * FP8-serialized checkpoints, which carry static weight scales and
      either dynamic or static activation scales.
    * FP16/BF16 checkpoints, whose weights are quantized once loading has
      finished and which use dynamic activation scaling.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # GPUs without FP8 hardware support can fall back to the Marlin
        # kernel for fast weight-only FP8 quantization.
        self.marlin_input_dtype = None
        marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Marlin is never used on ROCm or in batch-invariant mode.
        on_rocm = current_platform.is_rocm()
        batch_invariant = vllm_is_batch_invariant()
        if on_rocm or batch_invariant:
            marlin = False
        self.use_marlin = marlin

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        elif not self.act_q_static and cutlass_fp8_supported():
            # Use per-token quantization for better perf if dynamic and cutlass
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        if not self.block_quant:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        For FP8-serialized checkpoints the weight, weight scale and input
        scale parameters are registered on ``layer``. Otherwise only a
        ``params_dtype`` weight is registered, with a loader that triggers
        ``process_weights_after_loading`` once every element is loaded.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Bookkeeping read later by weight post-processing and `apply`.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        fp8_checkpoint = self.quant_config.is_checkpoint_fp8_serialized

        # WEIGHT
        if fp8_checkpoint:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # Start the per-layer element count on the first chunk.
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # Load this chunk while counting the elements it copies.
                numel_counter = CopyNumelCounter()
                with numel_counter:
                    loaded = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += numel_counter.copied_numel

                # Not every element is in yet: wait for more chunks.
                if layer._loaded_numel != layer.weight.numel():
                    return loaded

                # Everything is loaded: quantize right away.
                self.process_weights_after_loading(layer)

                # Drop the bookkeeping and turn the usual
                # `process_weights_after_loading` call into a no-op.
                del layer._loaded_numel
                layer._already_called_process_weights_after_loading = True
                return loaded

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # Scales are only loaded from FP8-serialized checkpoints; otherwise
        # they are produced in process_weights_after_loading.
        if not fp8_checkpoint:
            return

        # WEIGHT SCALE
        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            block_scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            set_weight_attrs(block_scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", block_scale)
        else:
            tensor_scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            set_weight_attrs(tensor_scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", tensor_scale)

        # INPUT ACTIVATION SCALE
        if not self.act_q_static:
            layer.register_parameter("input_scale", None)
            return
        act_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
        set_weight_attrs(act_scale, {"scale_type": "input_scale"})
        layer.register_parameter("input_scale", act_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the layer's weights and scales after loading.

        Does nothing when the layer is flagged as already processed.
        Otherwise the weight and scale parameters are replaced with their
        processed versions, the input scale is updated or cleared, and the
        Marlin or block-quant post-processing is applied where enabled.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Only the non-block layouts store the weight transposed.
        size_k_first = not self.block_quant
        input_scale = None

        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static

            block_weight, block_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", block_weight.data)
            replace_parameter(layer, "weight_scale_inv", block_scale_inv.data)
        else:
            if self.quant_config.is_checkpoint_fp8_serialized:
                # FP8 per-tensor checkpoint: a fused module carries N scales
                # for its N shards.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()
            else:
                # Checkpoint is not FP8: quantize the loaded weights now.
                quantized, weight_scale = ops.scaled_fp8_quant(
                    layer.weight, scale=None
                )
                weight = quantized.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is None:
            layer.input_scale = None
        else:
            replace_parameter(layer, "input_scale", input_scale)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized linear layer on ``x``.

        The first applicable path is taken, in this order: batch-invariant
        mode, Marlin, block-quantized FP8, and the default FP8 linear op.
        """
        # In batch-invariant mode block-quantized layers go through the
        # block FP8 op; everything else is dequantized to BF16 and run as a
        # regular GEMM.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )

            # per-tensor/channel: dequant to BF16 and run GEMM
            weight_as_bf16 = layer.weight.to(torch.bfloat16)
            scale_as_bf16 = layer.weight_scale.to(torch.bfloat16)

            # A 1-D scale with one entry per weight row is applied row-wise;
            # a single scale, or any other shape, is multiplied directly and
            # relies on broadcasting.
            row_wise = (
                scale_as_bf16.numel() != 1
                and scale_as_bf16.dim() == 1
                and scale_as_bf16.shape[0] == weight_as_bf16.shape[0]
            )
            if row_wise:
                scale_as_bf16 = scale_as_bf16.unsqueeze(1)
            dequantized = weight_as_bf16 * scale_as_bf16
            return torch.nn.functional.linear(x, dequantized.t(), bias)

        if self.use_marlin:
            marlin_scale = (
                layer.weight_scale_inv if self.block_quant else layer.weight_scale
            )
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=marlin_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if not self.block_quant:
            return self.fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                out_dtype=self.out_dtype,
                input_scale=layer.input_scale,
                bias=bias,
            )

        assert self.weight_block_size is not None
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )
709
710


711
712
713
714
715
716
717
718
719
720
721
722
723
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

724
725
726
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Pick the FP8 MoE backend for ``layer`` and validate its settings.

        Args:
            quant_config: The FP8 quantization config.
            layer: The MoE layer this method is attached to. Its
                ``moe_config`` is forwarded to the base class and its
                ``moe_parallel_config`` is used for backend selection.

        Raises:
            NotImplementedError: If the FlashInfer CUTLASS backend was
                selected but the layer uses an unsupported block size,
                routing setup, scoring function or activation.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Scale parameters are named differently for block vs. tensor quant.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            # Fail fast on configurations the CUTLASS backend cannot run.
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                if layer.renormalize or layer.custom_routing_function is not None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend does not support "
                        "custom routing function or renormalization, but got "
                        f"{layer.renormalize} and "
                        f"{layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                # The second fragment must be an f-string so the offending
                # activation is actually interpolated into the message.
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )
765

766
767
768
769
770
771
772
773
774
    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate FP8 expert weights, weight scales and input scales.

        Registers ``w13_weight`` and ``w2_weight`` (``torch.float8_e4m3fn``)
        on ``layer`` together with their scale parameters, and records the
        layer geometry as attributes on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the model.
            intermediate_size_per_partition: Intermediate size on this rank.
            params_dtype: Original dtype; stored as ``layer.orig_dtype``.
            **extra_weight_attrs: Attributes attached to every parameter.

        Raises:
            ValueError: If block quantization is used and the intermediate
                size is not aligned with the quantization block shape.
        """
        # Record geometry and the original dtype on the layer.
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # Weights are stored in FP8 from here on.
        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = self.weight_block_size[0]
            block_k = self.weight_block_size[1]
            # NOTE: the gate/up output size must be a multiple of block_n so
            # that block-wise scales stay aligned (needed for column
            # parallelism and merged weights).
            if intermediate_size_per_partition % block_n != 0:
                misaligned_n = (
                    "The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
                raise ValueError(misaligned_n)
            # Row parallelism additionally needs block_k alignment.
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                misaligned_k = (
                    "The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )
                raise ValueError(misaligned_k)

        # WEIGHTS
        w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
        w13_weight = torch.nn.Parameter(
            torch.empty(w13_shape, dtype=params_dtype), requires_grad=False
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)
        w2_weight = torch.nn.Parameter(
            torch.empty(w2_shape, dtype=params_dtype), requires_grad=False
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            # One scale per quantization block (typically 128x128).
            def num_blocks(size: int, block: int) -> int:
                return (size + block - 1) // block

            w13_scale_shape = (
                num_experts,
                2 * num_blocks(intermediate_size_per_partition, block_n),
                num_blocks(hidden_size, block_k),
            )
            w2_scale_shape = (
                num_experts,
                num_blocks(hidden_size, block_n),
                num_blocks(intermediate_size_per_partition, block_k),
            )
        else:
            # Per-tensor quant: one scale per expert and per weight shard.
            w13_scale_shape = (num_experts, 2)
            w2_scale_shape = (num_experts,)

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(w13_scale_shape, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(w2_scale_shape, dtype=torch.float32), requires_grad=False
        )
        # The parameter name is weight_scale (tensor) or weight_scale_inv (block).
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Tag the scales with the quantization granularity so the weight
        # loader handles them correctly.
        if self.block_quant:
            scale_kind = FusedMoeWeightScaleSupported.BLOCK.value
        else:
            scale_kind = FusedMoeWeightScaleSupported.TENSOR.value
        extra_weight_attrs["quant_method"] = scale_kind
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        assert not self.block_quant
        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
888

889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
    def _convert_weights_to_kernel_format(
        self,
        layer: Module,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        w13_weight_scale: torch.Tensor,
        w2_weight_scale: torch.Tensor,
    ) -> None:
        """Convert weights/scales to the layout of the selected backend.

        The (possibly transformed) tensors are written back to ``layer``
        through ``replace_parameter``. Backends without a dedicated branch
        keep the tensors as they were passed in.

        Args:
            layer: Module whose parameters are replaced.
            w13_weight: Gate/up projection weights.
            w2_weight: Down projection weights.
            w13_weight_scale: Scales matching ``w13_weight``.
            w2_weight_scale: Scales matching ``w2_weight``.
        """
        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            assert self.block_quant
            w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w13_weight,
                ws=w13_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
            w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w2_weight,
                ws=w2_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
        elif self.fp8_backend == Fp8MoeBackend.AITER:
            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                w13_weight, w2_weight
            )
        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            (
                workspace,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
            ) = prepare_moe_fp8_layer_for_marlin(
                layer,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
                input_dtype=self.marlin_input_dtype,
            )
            layer.workspace = workspace
        elif self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_CUTLASS,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            w13_weight = swap_w13_to_w31(w13_weight)
            if self.block_quant:
                w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
            else:
                # TODO(rob): this function is a hack that renames the scaling
                # factors in the Module. This is a hack we should clean up.
                register_moe_scaling_factors(layer)
                if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
        # NOTE: a second, unreachable `elif ... AITER` branch used to follow
        # here; AITER is already handled above, so it was removed.

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)

    def _setup_kernel(self, layer: Module) -> None:
        """Build ``self.kernel`` (a modular kernel) for the TP case.

        Sets ``self.moe_quant_config``, ``self.kernel`` and
        ``self.use_inplace`` for the FlashInfer CUTLASS, DeepGEMM, Triton,
        Marlin and AITER backends. Any other backend is left untouched.
        """
        # NOTE(rob): this is a WIP refactor. The TP-case kernels are being
        # migrated to the modular kernel (mk) first; afterwards the TP and
        # DP/EP cases are meant to be initialized through the same code path
        # (i.e. via maybe_init_modular_kernel).
        backend = self.fp8_backend

        if backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
                FlashInferAllGatherMoEPrepareAndFinalize,
            )

            quant_cfg = self.get_fused_moe_quant_config(layer)
            assert quant_cfg is not None
            self.moe_quant_config = quant_cfg

            # TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
            # with the changes to defer input quantization
            prepare_finalize = FlashInferAllGatherMoEPrepareAndFinalize(
                use_dp=(self.moe.dp_size > 1),
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            experts = FlashInferExperts(
                out_dtype=torch.get_default_dtype(),
                quant_config=self.moe_quant_config,
                ep_rank=self.moe.ep_rank,
                ep_size=self.moe.ep_size,
                tp_rank=self.moe.tp_rank,
                tp_size=self.moe.tp_size,
                use_dp=(self.moe.dp_size > 1),
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            self.kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)
            self.use_inplace = False
            return

        no_ep_backends = (
            Fp8MoeBackend.DEEPGEMM,
            Fp8MoeBackend.TRITON,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.AITER,
        )
        if backend not in no_ep_backends:
            return

        from vllm.model_executor.layers.fused_moe import (
            TritonOrDeepGemmExperts,
        )
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )
        from vllm.model_executor.layers.fused_moe.prepare_finalize import (
            MoEPrepareAndFinalizeNoEP,
        )
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        quant_cfg = self.get_fused_moe_quant_config(layer)
        assert quant_cfg is not None
        self.moe_quant_config = quant_cfg

        if backend == Fp8MoeBackend.AITER:
            # TODO: make defer_input_quant an attr of the AiterExperts
            prepare_finalize = MoEPrepareAndFinalizeNoEP(defer_input_quant=True)
            experts = AiterExperts(quant_config=self.moe_quant_config)
        elif backend == Fp8MoeBackend.MARLIN:
            prepare_finalize = MoEPrepareAndFinalizeNoEP()
            experts = MarlinExperts(quant_config=self.moe_quant_config)
        else:
            prepare_finalize = MoEPrepareAndFinalizeNoEP()
            experts = TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(backend == Fp8MoeBackend.DEEPGEMM),
            )

        self.kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)
        self.use_inplace = True

1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded FP8 weights and build the runtime kernel.

        Does nothing when ``layer`` is already flagged as processed.
        Otherwise it optionally converts to the FNUZ format, collapses
        static input scales to a single value, requantizes w13 to one scale
        per expert (per-tensor quant only), converts the weights to the
        backend layout and sets up the modular kernel.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Fetch weights and scales under backend-independent local names.
        w13_weight, w2_weight = layer.w13_weight, layer.w2_weight
        w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")

        # MI300x and MI325x use the FNUZ FP8 format; convert when required.
        if current_platform.is_fp8_fnuz():
            w13_normalized = normalize_e4m3fn_to_e4m3fnuz(
                w13_weight, w13_weight_scale, layer.w13_input_scale
            )
            w13_weight, w13_weight_scale, layer.w13_input_scale = w13_normalized
            w2_normalized = normalize_e4m3fn_to_e4m3fnuz(
                w2_weight, w2_weight_scale, layer.w2_input_scale
            )
            w2_weight, w2_weight_scale, layer.w2_input_scale = w2_normalized

        # Per-tensor kernels need one activation scale: take the maximum.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert layer.w13_input_scale is not None
            assert layer.w2_input_scale is not None
            scales_are_uniform = all_close_1d(
                layer.w13_input_scale
            ) and all_close_1d(layer.w2_input_scale)
            if not scales_are_uniform:
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )
            replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
            replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())

        # Per-tensor kernels need a single w13 scale per expert, while the
        # checkpoint stores one for w1 and one for w3: requantize both
        # shards with the larger of the two.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                for shard_id, start in enumerate((0, shard_size)):
                    rows = slice(start, start + shard_size)
                    dq_weight = per_tensor_dequantize(
                        w13_weight[expert_id][rows, :],
                        w13_weight_scale[expert_id][shard_id],
                    )
                    requantized, _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id]
                    )
                    w13_weight[expert_id][rows, :] = requantized
            w13_weight_scale = max_w13_scales

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)

1105
1106
1107
1108
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for this backend, if any.

        Returns ``None`` for the AITER, Marlin and FlashInfer TRT-LLM
        backends, a FlashInfer CUTLASS specific object for the CUTLASS
        backend, and otherwise defers to the base class.
        """
        has_no_prepare_finalize = (
            self.fp8_backend in (Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN)
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        )
        if has_no_prepare_finalize:
            return None

        if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize(routing_tables)

        if self.block_quant:
            assert self.weight_block_size == [128, 128], (
                f"Only support weight_block_size == [128, 128], "
                f"got {self.weight_block_size}"
            )
        # Forward the block-scale flag to the CUTLASS prepare/finalize.
        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            self.moe,
            use_deepseek_fp8_block_scale=self.block_quant,
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
1130

bnellnm's avatar
bnellnm committed
1131
1132
1133
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation matching ``prepare_finalize``.

        Args:
            prepare_finalize: Its activation format decides between the
                batched and the standard experts implementations.
            layer: Unused here.

        Raises:
            NotImplementedError: For the Marlin and AITER backends.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in (Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER):
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None
        deep_gemm_selected = self.fp8_backend == Fp8MoeBackend.DEEPGEMM

        uses_batched_format = (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        )
        if uses_batched_format:
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            if deep_gemm_selected:
                experts_impl = BatchedDeepGemmExperts
            else:
                experts_impl = BatchedTritonExperts
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )

        if self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Pass the block-scale flag along when weights are block-quantized.
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts

        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__,
            self.weight_block_size,
            False,
        )
        return TritonOrDeepGemmExperts(
            quant_config=self.moe_quant_config,
            allow_deep_gemm=deep_gemm_selected,
        )

1198
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the scales on ``layer``.

        The Marlin backend gets a w8a16 config (weight scales only); every
        other backend gets a w8a8 config that also carries the input scales.
        """
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")

        if self.fp8_backend != Fp8MoeBackend.MARLIN:
            return fp8_w8a8_moe_quant_config(
                w1_scale=w13_scale,
                w2_scale=w2_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.weight_block_size,
            )

        return fp8_w8a16_moe_quant_config(
            w1_scale=w13_scale,
            w2_scale=w2_scale,
            block_shape=self.weight_block_size,
        )

1216
1217
1218
1219
1220
1221
1222
1223
    @property
    def supports_eplb(self) -> bool:
        """Whether this method reports EPLB support (always ``True``)."""
        return True

    @property
    def allow_inplace(self) -> bool:
        """Whether this method allows in-place execution (always ``True``)."""
        return True

1224
1225
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass for ``layer``.

        The FlashInfer TRT-LLM backend is dispatched directly to its fused
        ops and returns from inside that branch. All other backends select
        experts via ``layer.select_experts`` and run ``self.kernel``.

        Args:
            layer: The MoE layer holding weights, scales and routing config.
            x: Input hidden states.
            router_logits: Router logits used for expert selection.

        Returns:
            The output of the selected backend.

        Raises:
            NotImplementedError: If EPLB is enabled on the TRT-LLM path.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                # Imported for its side effect; the op is called through
                # torch.ops.vllm below.
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )

            assert not layer.renormalize and layer.custom_routing_function is not None
            # Return here: falling through would discard this result and
            # call self.kernel, which _setup_kernel never builds for the
            # TRT-LLM backend.
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        # Modular-kernel path: route tokens, then run the fused experts.
        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1302
1303


1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.
    Supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Initialize the base method, then check the online-quant invariants.

        Online quantization requires a checkpoint that is not FP8
        serialized, dynamic activation scaling, no block quantization and
        no FlashInfer MoE backend.
        """
        super().__init__(quant_config, layer)
        # Checks on the config itself.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        # Check on the backend chosen by the base class.
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate unquantized expert weights that are quantized on load.

        The weights are created in ``params_dtype``. Their weight loader is
        wrapped so that ``process_weights_after_loading`` runs as soon as
        the number of copied elements equals the combined size of
        ``w13_weight`` and ``w2_weight``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the model.
            intermediate_size_per_partition: Intermediate size on this rank.
            params_dtype: Dtype of the weights being loaded.
            **extra_weight_attrs: Attributes attached to the weights; must
                contain ``"weight_loader"``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # Online quantization: wrap the weight loader so quantization is
        # triggered in a streaming fashion once the last chunk has arrived.
        original_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Running count of elements copied into this layer so far.
            layer._loaded_numel = getattr(layer, "_loaded_numel", 0)

            # Load the current chunk while counting the copied elements.
            numel_counter = CopyNumelCounter()
            with numel_counter:
                res = original_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += numel_counter.copied_numel

            # Once everything is loaded, quantize right away.
            expected_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == expected_numel:
                self.process_weights_after_loading(layer)

                # Drop the bookkeeping counter.
                del layer._loaded_numel
                # Turn the regular `process_weights_after_loading` call
                # into a no-op.
                layer._already_called_process_weights_after_loading = True

            return res

        # Same attributes as passed in, but with the wrapped loader.
        extra_weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS
        w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
        w13_weight = torch.nn.Parameter(
            torch.empty(w13_shape, dtype=params_dtype), requires_grad=False
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)
        w2_weight = torch.nn.Parameter(
            torch.empty(w2_shape, dtype=params_dtype), requires_grad=False
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for each weight; the values are filled in
        # when the weights are quantized after loading.
        per_expert_shape = (num_experts,)
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(per_expert_shape, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(per_expert_shape, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # No input-scale parameters are created here.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weights to FP8 and build the runtime kernel.

        Does nothing when ``layer`` is already flagged as processed. Each
        local expert is quantized separately and its scale is written into
        ``layer.w13_weight_scale`` / ``layer.w2_weight_scale``.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Destination buffers in the platform's FP8 dtype.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_quantized, w13_scale = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w13_weight[expert, :, :] = w13_quantized
            layer.w13_weight_scale[expert] = w13_scale

            w2_quantized, w2_scale = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )
            w2_weight[expert, :, :] = w2_quantized
            layer.w2_weight_scale[expert] = w2_scale

        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)
1438
1439


1440
1441
1442
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1443
1444
1445
    """

    def __init__(self, quant_config: Fp8Config):
1446
        super().__init__(quant_config)