fp8.py 58.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8

import torch
from torch.nn import Module
9
from torch.utils._python_dispatch import TorchDispatchMode
10

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm._aiter_ops import rocm_aiter_ops
15
from vllm.attention.layer import Attention
16
from vllm.distributed import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.batch_invariant import (
19
    vllm_is_batch_invariant,
20
)
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe import (
22
23
24
25
26
27
28
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
    FusedMoEParallelConfig,
31
    FusedMoEQuantConfig,
32
    RoutingMethodType,
33
    fp8_w8a8_moe_quant_config,
34
    fp8_w8a16_moe_quant_config,
35
36
37
38
39
40
41
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
from vllm.model_executor.layers.quantization.base_config import (
44
45
46
    QuantizationConfig,
    QuantizeMethodBase,
)
47
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
48
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
49
50
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
51
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
52
    get_flashinfer_moe_backend,
53
54
    make_fp8_moe_alpha_scales_for_fi,
    register_scales_for_trtllm_fp8_per_tensor_moe,
55
56
57
58
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
64
    deepgemm_post_process_fp8_weight_block,
65
66
67
68
69
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
70
71
72
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
73
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
74
75
76
77
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
81
    GroupShape,
    is_layer_skipped,
)
82
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
83
84
85
86
87
88
89
90
91
92
93
94
95
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
96
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
97
from vllm.platforms import current_platform
98
99
100
101
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
102
from vllm.utils.flashinfer import has_flashinfer_moe
103
from vllm.utils.import_utils import has_deep_gemm
104

105
106
107
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

108
109
110
111
# Activation quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger used for the backend-selection messages below.
logger = init_logger(__name__)

112

113
114
115
116
117
class Fp8MoeBackend(Enum):
    """Identifiers for the FP8 MoE kernel backends.

    ``get_fp8_moe_backend`` returns one of these members to name the primary
    backend chosen for a fused-MoE layer.
    """

    NONE = 0
    # FlashInfer backends (selected on SM90 / SM100-family CUDA devices).
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    # DeepGEMM, only considered for block-quantized weights.
    DEEPGEMM = 3
    # Weight-only path for GPUs without native FP8 support.
    MARLIN = 4
    # Default backend; also the one used when LoRA support is required.
    TRITON = 5
    # ROCm AITER kernels.
    AITER = 6
121
122


123
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend | None:
    """
    Select the primary FP8 MoE backend
    Note: Shape-specific fallbacks may still occur at runtime.

    The checks run in priority order and the first match wins:
    XPU (no backend) -> LoRA (Triton) -> FlashInfer -> Marlin -> DeepGEMM
    -> ROCm AITER -> Triton.

    Args:
        block_quant: Whether the weights use block-wise quantization.
        moe_parallel_config: Parallel configuration of the MoE layer; only
            ``tp_size`` is read here.
        with_lora_support: Whether the layer must support LoRA.

    Returns:
        The selected backend, or ``None`` on XPU platforms.

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) backend is
            selected together with block quantization on an SM100-family
            device.
    """
    if current_platform.is_xpu():
        return None
    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        else:
            if block_quant and current_platform.is_device_capability_family(100):
                raise ValueError(
                    "FlashInfer FP8 MoE throughput backend does not "
                    "support block quantization on SM100. Please use "
                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
                    "instead."
                )
            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
            return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    # Marlin is never used on ROCm, even when forced via the test env var.
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # Determine if we should use DeepGEMM with block-quantized weights:
    # - If explicitly set by user, respect their choice
    # - If not explicitly set (default), disable when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    # Determine if we should use DeepGEMM (top-level enable switch)
    # - If explicitly set by user, respect their choice
    # - If not platform supports DeepGEMM, disable it
    # This helps avoid warning messages on unsupported platforms.
    use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
    if not is_deep_gemm_supported():
        use_deep_gemm = False
        logger.info_once(
            "DeepGEMM is disabled because the platform does not support it.",
            scope="local",
        )

    if use_deep_gemm and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            # Requested but the package is missing: warn and fall through to
            # the remaining backends instead of failing.
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM

    # AITER is a ROCm-only backend: require the ROCm platform in addition to
    # the env flags so that stray VLLM_ROCM_* settings on another platform
    # cannot select a backend that is unavailable there.
    if (
        current_platform.is_rocm()
        and envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_MOE
    ):
        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
        return Fp8MoeBackend.AITER

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


214
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Attributes:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights (as opposed to being quantized after loading).
        activation_scheme: One of ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer names handed to ``is_layer_skipped``; matching
            layers get the unquantized method.
        weight_block_size: Two-element block size for block-wise weight
            quantization, or ``None`` when block quantization is not used.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Raises:
            ValueError: If ``activation_scheme`` is unknown, or if
                ``weight_block_size`` is given without an FP8-serialized
                checkpoint, with a length other than 2, or with a
                non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes supported by this config."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` through the given weights mapper."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dict.

        ``ignored_layers`` falls back to the ``modules_to_not_convert`` key
        when it is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer`` on XPU platforms.

        Returns ``None`` for layer types other than ``LinearBase``,
        ``FusedMoE`` and ``Attention``.
        """
        # Imported lazily so the XPU-specific module is only loaded when this
        # path is actually taken.
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh config carrying the same settings.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)

            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method to use for ``layer``.

        Delegates to ``get_xpu_quant_method`` on XPU. Otherwise linear and
        fused-MoE layers get an FP8 method unless ``is_layer_skipped`` says
        they are ignored, attention layers get the FP8 KV-cache method, and
        any other layer type yields ``None``.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            # Serialized FP8 checkpoints and online-quantized checkpoints use
            # different MoE method classes.
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

377

378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that accumulates how many elements were written through
    `copy_` while it is active. Handy during weight loading, where the
    destination may be an arbitrary view of the parameter (for example the
    result of `narrow`) rather than the parameter itself.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements seen by `copy_`.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Always run the intercepted op first, exactly as requested.
        call_kwargs = {} if kwargs is None else kwargs
        result = func(*args, **call_kwargs)

        # Only in-place copies contribute; the first positional argument is
        # the destination tensor.
        if func == torch.ops.aten.copy_.default:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


398
399
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Pick the kernel path (Marlin / block FP8 / per-tensor FP8)."""
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Marlin is also disabled in batch-invariant mode.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        # Exactly one of the two linear ops is created, depending on whether
        # the weights are block-quantized.
        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        For FP8-serialized checkpoints this registers ``weight``, a weight
        scale (``weight_scale`` or ``weight_scale_inv`` for block quant) and
        ``input_scale``. For other checkpoints only ``weight`` is registered
        in ``params_dtype``, with a loader wrapper that triggers
        ``process_weights_after_loading`` once the whole weight is loaded.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded weights/scales into the form ``apply`` expects.

        Does nothing when the layer is flagged as already processed (set by
        the patched weight loader in ``create_weights``).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` and return the output tensor.

        Dispatch order: batch-invariant mode, then Marlin, then block FP8,
        then the per-tensor/per-token FP8 op.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM
                weight_fp8 = layer.weight.to(torch.bfloat16)
                weight_scale = layer.weight_scale.to(torch.bfloat16)
                # Per-row scaling applies only when there are several scales
                # in a 1-D tensor whose length matches the weight's first
                # dimension. A single scale, or any other shape, falls back to
                # a plain broadcasted multiply.
                if (
                    weight_scale.numel() != 1
                    and weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    weight_bf16 = weight_fp8 * weight_scale
                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
722
723


724
725
726
727
728
729
730
731
732
733
734
735
736
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

737
738
739
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Pick the FP8 MoE backend for ``layer`` and validate its settings.

        Args:
            quant_config: The FP8 quantization config. Its
                ``weight_block_size`` decides whether block quantization is
                used, and its ``activation_scheme`` decides whether
                activations are quantized statically or dynamically.
            layer: The MoE layer this method is attached to. Its
                ``moe_config`` is forwarded to the base class and its
                ``moe_parallel_config`` feeds the backend selection.

        Raises:
            NotImplementedError: If the selected FlashInfer backend is
                combined with a block size, routing setup, scoring function,
                activation function, or activation quantization scheme that
                the checks below reject.
        """
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Block-quantized checkpoints register their scales under the
        # "weight_scale_inv" name; per-tensor ones use "weight_scale".
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                if layer.renormalize or layer.custom_routing_function is not None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend does not support "
                        "custom routing function or renormalization, but got "
                        f"{layer.renormalize} and "
                        f"{layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                # The second fragment must be an f-string, otherwise the
                # placeholder is emitted literally instead of the activation.
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

        # Without block quantization, any non-static activation scheme means
        # dynamic per-token quantization, which both FlashInfer backends reject.
        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if self.flashinfer_moe_backend is not None and dynamic_per_token:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )
786

787
788
789
790
791
792
793
794
795
    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the FP8 expert weights and scales and register them on ``layer``.

        Args:
            layer: Module that receives the parameters and bookkeeping attributes.
            num_experts: Size of the leading (expert) dimension of every tensor.
            hidden_size: Hidden dimension of the MoE layer.
            intermediate_size_per_partition: Intermediate dimension held by
                this partition.
            params_dtype: Incoming dtype; stored as ``layer.orig_dtype`` and
                then replaced by ``torch.float8_e4m3fn`` for the weights.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``. A ``quant_method`` entry is added before
                the scale parameters are tagged.

        Raises:
            ValueError: If block quantization is enabled and the intermediate
                size is not divisible by the relevant block dimension.
        """

        def _frozen(data: torch.Tensor) -> torch.nn.Parameter:
            # Every tensor created here is inference-only.
            return torch.nn.Parameter(data, requires_grad=False)

        # Bookkeeping read back later by weight post-processing and kernels.
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # This method only handles checkpoints that are already FP8 on disk.
        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = self.weight_block_size[0], self.weight_block_size[1]
            # NOTE: the block-wise scales only line up if the output size of
            # the gate and up projections is a multiple of block_n. This is
            # required by column parallelism and by merged weights.
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            # Row parallelism additionally needs block_k alignment on the
            # input side of the down projection.
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = _frozen(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            )
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = _frozen(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            )
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            # One scale per quantization block (typically 128x128); block
            # counts are rounded up.
            inter_blocks_n = (intermediate_size_per_partition + block_n - 1) // block_n
            inter_blocks_k = (intermediate_size_per_partition + block_k - 1) // block_k
            hidden_blocks_n = (hidden_size + block_n - 1) // block_n
            hidden_blocks_k = (hidden_size + block_k - 1) // block_k
            w13_scale_data = torch.ones(
                num_experts,
                2 * inter_blocks_n,
                hidden_blocks_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                hidden_blocks_n,
                inter_blocks_k,
                dtype=torch.float32,
            )
        else:
            # Per-tensor quant: one scale per expert and per logical weight
            # (two for w13, one for w2).
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)

        w13_weight_scale = _frozen(w13_scale_data)
        w2_weight_scale = _frozen(w2_scale_data)
        # The registered name is "weight_scale" for per-tensor and
        # "weight_scale_inv" for block quantization.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Tag the scales with the quantization granularity so that they are
        # loaded correctly. The weights above were tagged before this entry
        # was added and therefore do not carry it.
        if self.block_quant:
            scale_kind = FusedMoeWeightScaleSupported.BLOCK
        else:
            scale_kind = FusedMoeWeightScaleSupported.TENSOR
        extra_weight_attrs.update({"quant_method": scale_kind.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            # Dynamic activation quantization carries no stored input scales.
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        assert not self.block_quant
        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = _frozen(torch.ones(num_experts, dtype=torch.float32))
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
909

910
911
912
913
914
915
916
    def _convert_weights_to_kernel_format(
        self,
        layer: Module,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        w13_weight_scale: torch.Tensor,
        w2_weight_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights and scales to the selected backend's layout.

        The converted tensors replace the ``w13_weight``, ``w2_weight`` and
        the two weight-scale parameters on ``layer``.

        Args:
            layer: The MoE layer whose parameters are replaced.
            w13_weight: Expert weights registered as ``w13_weight``.
            w2_weight: Expert weights registered as ``w2_weight``.
            w13_weight_scale: Scale tensor belonging to ``w13_weight``.
            w2_weight_scale: Scale tensor belonging to ``w2_weight``.
            w13_input_scale: Activation scale for the first GEMM, or ``None``.
            w2_input_scale: Activation scale for the second GEMM, or ``None``.
        """
        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            assert self.block_quant
            w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w13_weight,
                ws=w13_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
            w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w2_weight,
                ws=w2_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
        elif self.fp8_backend == Fp8MoeBackend.AITER:
            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                w13_weight, w2_weight
            )
        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            (
                workspace,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
            ) = prepare_moe_fp8_layer_for_marlin(
                layer,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
                input_dtype=self.marlin_input_dtype,
            )
            layer.workspace = workspace
        elif self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_CUTLASS,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            w13_weight = swap_w13_to_w31(w13_weight)
            if self.block_quant:
                # Block scales follow the same w13 -> w31 ordering as the
                # weights they describe.
                w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
            elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
                rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
                # Register the *scales* here: passing the weight tensors in
                # the scale slots would derive the kernel's dequantization
                # factors from raw FP8 weights.
                register_scales_for_trtllm_fp8_per_tensor_moe(
                    layer=layer,
                    w13_weight_scale=w13_weight_scale,
                    w13_input_scale=w13_input_scale,
                    w2_weight_scale=w2_weight_scale,
                    w2_input_scale=w2_input_scale,
                )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)

    def _setup_kernel(self, layer: Module) -> None:
        """Build ``self.kernel`` (a modular kernel) for the TP case.

        Also stores ``self.moe_quant_config`` and ``self.use_inplace``.
        Nothing is built for the FlashInfer TRTLLM backend.
        """
        # NOTE(rob): this is a WIP refactor. The TP-case kernels are being
        # migrated to the modular kernel (mk) first. Once that is done, the
        # TP case and the DP/EP case will be initialized through the same
        # code path (i.e. via maybe_init_modular_kernel).
        # NOTE(rob): the migration of all kernels to this format is ongoing.
        from vllm.model_executor.layers.fused_moe import (
            TritonOrDeepGemmExperts,
        )
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )
        from vllm.model_executor.layers.fused_moe.prepare_finalize import (
            MoEPrepareAndFinalizeNoEP,
        )
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        backend = self.fp8_backend

        # Flashinfer TRTLLM does not use the modular kernel abstraction.
        if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return

        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None
        self.use_inplace = True

        if backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # TODO: make defer_input_quant an attr of the FlashInferExperts
            prepare_finalize = MoEPrepareAndFinalizeNoEP(
                defer_input_quant=self.block_quant
            )
            experts = FlashInferExperts(
                out_dtype=layer.orig_dtype,
                quant_config=self.moe_quant_config,
                ep_rank=self.moe.ep_rank,
                ep_size=self.moe.ep_size,
                tp_rank=self.moe.tp_rank,
                tp_size=self.moe.tp_size,
                use_dp=(self.moe.dp_size > 1),
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            self.kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)
            # This is the only backend here that is run out of place.
            self.use_inplace = False
            return

        if backend == Fp8MoeBackend.AITER:
            # TODO: make defer_input_quant an attr of the AiterExperts
            prepare_finalize = MoEPrepareAndFinalizeNoEP(defer_input_quant=True)
            experts = AiterExperts(quant_config=self.moe_quant_config)
        elif backend == Fp8MoeBackend.MARLIN:
            prepare_finalize = MoEPrepareAndFinalizeNoEP()
            experts = MarlinExperts(quant_config=self.moe_quant_config)
        else:
            prepare_finalize = MoEPrepareAndFinalizeNoEP()
            experts = TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(backend == Fp8MoeBackend.DEEPGEMM),
            )
        self.kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)
1053

1054
1055
1056
1057
1058
1059
1060
1061
1062
    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process the loaded FP8 weights and set up the runtime kernel.

        Steps, in order: optional FNUZ conversion, collapsing static input
        scales to their maximum, requantizing w13 to a single per-expert scale
        (per-tensor quantization only), converting to the backend's layout,
        and building the modular kernel. Does nothing if the layer is already
        flagged as processed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Fetch weights and scales under backend-independent local names.
        w13_scale_attr = f"w13_{self.weight_scale_name}"
        w2_scale_attr = f"w2_{self.weight_scale_name}"
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight
        w13_weight_scale = getattr(layer, w13_scale_attr)
        w2_weight_scale = getattr(layer, w2_scale_attr)
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            (
                w13_weight,
                w13_weight_scale,
                w13_input_scale,
            ) = normalize_e4m3fn_to_e4m3fnuz(
                w13_weight, w13_weight_scale, w13_input_scale
            )
            (
                w2_weight,
                w2_weight_scale,
                w2_input_scale,
            ) = normalize_e4m3fn_to_e4m3fnuz(
                w2_weight, w2_weight_scale, w2_input_scale
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            scales_uniform = all_close_1d(w13_input_scale) and all_close_1d(
                w2_input_scale
            )
            if not scales_uniform:
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )
            replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
            replace_parameter(layer, "w2_input_scale", w2_input_scale.max())

        # Per tensor kernels require single weight scale for w13 per expert,
        # but on disk there is a scale for w1 and w3. Requantize both shards
        # of every expert against the larger of the two scales.
        if not self.block_quant:
            rows_per_shard = layer.intermediate_size_per_partition
            max_w13_scales = w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                expert_weight = w13_weight[expert_id]
                shared_scale = max_w13_scales[expert_id]
                for shard_id in range(2):
                    rows = slice(
                        shard_id * rows_per_shard,
                        (shard_id + 1) * rows_per_shard,
                    )
                    dequantized = per_tensor_dequantize(
                        expert_weight[rows, :],
                        w13_weight_scale[expert_id][shard_id],
                    )
                    requantized, _ = ops.scaled_fp8_quant(dequantized, shared_scale)
                    # Written in place through the per-expert view.
                    expert_weight[rows, :] = requantized
            w13_weight_scale = max_w13_scales

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer=layer,
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)

1122
1123
1124
1125
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend, if any.

        AITER, Marlin and FlashInfer TRTLLM yield ``None``; FlashInfer CUTLASS
        builds its own object; every other backend defers to the base class.
        """
        if self.fp8_backend in (Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN):
            return None
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            return None

        if self.fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS:
            return super().maybe_make_prepare_finalize(routing_tables)

        cutlass_prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            self.moe,
            use_deepseek_fp8_block_scale=self.block_quant,
        )
        logger.debug_once("%s", cutlass_prepare_finalize.__class__.__name__)
        return cutlass_prepare_finalize
1140

bnellnm's avatar
bnellnm committed
1141
1142
1143
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation matching ``prepare_finalize``.

        Raises:
            NotImplementedError: For the Marlin and AITER backends.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in (Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER):
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None
        quant_config = self.moe_quant_config
        uses_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
        method_name = self.__class__.__name__

        batched_format = (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        )
        if batched_format:
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            if uses_deep_gemm:
                experts_impl = BatchedDeepGemmExperts
            else:
                experts_impl = BatchedTritonExperts
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                method_name,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=quant_config,
            )

        if self.moe.is_lora_enabled:
            return TritonExperts(quant_config=quant_config)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts

        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            method_name,
            self.weight_block_size,
            False,
        )
        return TritonOrDeepGemmExperts(
            quant_config=quant_config,
            allow_deep_gemm=uses_deep_gemm,
        )

1208
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the quant config for the modular kernel.

        Returns ``None`` for FlashInfer TRTLLM, a W8A16 config for Marlin, and
        a W8A8 config for every other backend.
        """
        backend = self.fp8_backend

        # TRTLLM does not use Modular Kernel.
        if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")

        # MARLIN uses mixed precision W8A16 config.
        if backend == Fp8MoeBackend.MARLIN:
            return fp8_w8a16_moe_quant_config(
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                block_shape=self.weight_block_size,
            )

        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        per_tensor_cutlass = (
            backend == Fp8MoeBackend.FLASHINFER_CUTLASS and not self.block_quant
        )
        if not per_tensor_cutlass:
            # All other backends use normal config.
            return fp8_w8a8_moe_quant_config(
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                block_shape=self.weight_block_size,
            )

        # Flashinfer CUTLASS per-tensor uses single dq scale
        # (alpha = w_scale * a_scale) and inverse a2 scale.
        g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
            w1_scale,
            a1_scale,
            w2_scale,
            a2_scale,
        )
        return fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=(1.0 / a2_scale),
            g1_alphas=g1_alphas,
            g2_alphas=g2_alphas,
        )

1258
1259
1260
1261
1262
1263
1264
1265
    @property
    def supports_eplb(self) -> bool:
        """Always report EPLB as supported.

        NOTE(review): ``apply`` still raises ``NotImplementedError`` on the
        FlashInfer TRTLLM path when ``layer.enable_eplb`` is set.
        """
        return True

    @property
    def allow_inplace(self) -> bool:
        """Always report that in-place execution is allowed."""
        return True

1266
1267
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass for ``layer``.

        The FlashInfer TRTLLM backend is dispatched directly to its fused ops
        and returns from inside its branch. Every other backend selects
        experts on the layer and runs ``self.kernel``.

        Args:
            layer: The MoE layer holding weights, scales and routing settings.
            x: Input hidden states.
            router_logits: Router logits used for expert selection.

        Returns:
            The output of the selected kernel.

        Raises:
            NotImplementedError: If EPLB is enabled on the TRTLLM path.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                # Imported for its side effect; the op is then reached
                # through ``torch.ops.vllm``.
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    # DeepSeekV3 routing gets float32 logits; other routing
                    # methods get the logits unchanged.
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )

            assert not layer.renormalize and layer.custom_routing_function is not None
            # Return here. Falling through would discard this result and call
            # ``self.kernel``, which ``_setup_kernel`` does not build for the
            # TRTLLM backend.
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1344
1345


1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads expert weights from a checkpoint that is *not* FP8-serialized
    (weights arrive in ``params_dtype``) and quantizes them to FP8 once
    loading has finished. Activations use dynamic scaling; the per-expert
    weight scales are computed after the model weights are loaded.

    Args:
        quant_config: The quantization config. Must describe a
            non-FP8-serialized checkpoint with the ``"dynamic"`` activation
            scheme and no weight block size.
        layer: The MoE layer this method is attached to (forwarded to the
            parent constructor).
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Preconditions for the online-quantization path.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the unquantized expert weights and their scales on `layer`.

        The weights are allocated in ``params_dtype``. Their weight loader is
        wrapped so that FP8 quantization is triggered as soon as every
        element of ``w13_weight`` and ``w2_weight`` has been loaded.

        Args:
            layer: Module that receives the parameters and bookkeeping
                attributes.
            num_experts: Size of the leading (expert) dimension of each
                weight and scale tensor.
            hidden_size: Hidden dimension of the expert weights.
            intermediate_size_per_partition: Intermediate dimension of the
                expert weights on this partition.
            params_dtype: Dtype in which the checkpoint weights are loaded.
            **extra_weight_attrs: Attributes attached to each weight
                parameter. Must contain ``"weight_loader"``.

        Raises:
            KeyError: If ``extra_weight_attrs`` has no ``"weight_loader"``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization: patch the weight loader to call
        # `process_weights_after_loading` in a streaming fashion as soon as
        # the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Lazily create the counter of elements loaded so far.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Load the current weight chunk while counting the elements the
            # original loader copies.
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # Once all of the elements have been loaded, quantize right away.
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a genuinely new attribute holder (a copy, not an alias) that
        # carries the patched loader, so the received mapping is left as-is.
        weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, weight_attrs)

        # WEIGHT_SCALES
        # One float32 scale per expert for each of w13 and w2, initialized
        # to 1. They are overwritten in `process_weights_after_loading`.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # No input scales are stored on the layer for this method.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded expert weights to FP8 and prepare the kernel.

        Does nothing if the streaming weight loader installed by
        `create_weights` has already run this step for `layer`.

        Args:
            layer: Module holding ``w13_weight``, ``w2_weight`` and their
                ``*_weight_scale`` parameters.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Quantize into freshly allocated FP8 tensors, expert by expert, and
        # record each expert's scale.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
            )
        # Swap the original-dtype parameters for their FP8 counterparts.
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer=layer,
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            w13_weight_scale=layer.w13_weight_scale,
            w2_weight_scale=layer.w2_weight_scale,
            w13_input_scale=None,
            w2_input_scale=None,
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)
1486
1487


1488
1489
1490
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1491
1492
1493
    """

    def __init__(self, quant_config: Fp8Config):
1494
        super().__init__(quant_config)