fp8.py 56 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8

import torch
from torch.nn import Module
9
from torch.utils._python_dispatch import TorchDispatchMode
10

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm._aiter_ops import rocm_aiter_ops
15
from vllm.attention.layer import Attention
16
from vllm.distributed import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.batch_invariant import (
19
    vllm_is_batch_invariant,
20
)
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe import (
22
23
24
25
26
27
28
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
    FusedMoEParallelConfig,
31
    FusedMoEQuantConfig,
32
    RoutingMethodType,
33
    fp8_w8a8_moe_quant_config,
34
    fp8_w8a16_moe_quant_config,
35
36
37
38
39
40
41
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
from vllm.model_executor.layers.quantization.base_config import (
44
45
46
    QuantizationConfig,
    QuantizeMethodBase,
)
47
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
48
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
49
50
51
52
53
54
55
56
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
57
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
58
59
60
61
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
62
    deepgemm_post_process_fp8_weight_block,
63
64
65
66
67
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
68
69
70
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
71
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
72
73
74
75
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
76
from vllm.model_executor.layers.quantization.utils.quant_utils import (
77
78
79
    GroupShape,
    is_layer_skipped,
)
80
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
81
82
83
84
85
86
87
88
89
90
91
92
93
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
94
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
95
from vllm.platforms import current_platform
96
97
98
99
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
100
from vllm.utils.flashinfer import has_flashinfer_moe
101
from vllm.utils.import_utils import has_deep_gemm
102

103
104
105
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

106
107
108
109
# Module-level logger used by the backend-selection and loading code below.
logger = init_logger(__name__)

# Activation quantization schemes accepted by ``Fp8Config``; any other value
# makes its constructor raise ``ValueError``.
ACTIVATION_SCHEMES = ["static", "dynamic"]

110

111
112
113
114
115
class Fp8MoeBackend(Enum):
    """Kernel families that can run an FP8 fused-MoE layer.

    ``get_fp8_moe_backend`` picks one of these members (or ``None`` on XPU).
    The integer values are part of the enum definition and are kept stable.
    """

    NONE = 0  # no backend chosen
    FLASHINFER_TRTLLM = 1  # FlashInfer TensorRT-LLM ("latency") kernels
    FLASHINFER_CUTLASS = 2  # FlashInfer CUTLASS kernels
    DEEPGEMM = 3  # DeepGEMM kernels for block-quantized weights
    MARLIN = 4  # Marlin weight-only FP8 path for GPUs without native FP8
    TRITON = 5  # Triton fused-MoE kernels (the default fallback)
    AITER = 6  # ROCm AITER kernels
119
120


121
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend | None:
    """
    Select the primary FP8 MoE backend.

    Candidates are tried in this order:
      1. XPU -> ``None``.
      2. LoRA enabled -> Triton.
      3. FlashInfer (CUDA, SM100 family or SM90, when enabled and installed).
      4. Marlin weight-only path (no native FP8, or forced via env; never ROCm).
      5. DeepGEMM (block-quantized weights, when enabled and supported).
      6. ROCm AITER (ROCm only, when enabled via env).
      7. Triton (default).

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the MoE weights use block-wise quantization.
        moe_parallel_config: Parallel config; its ``tp_size`` decides whether
            DeepGEMM is enabled by default.
        with_lora_support: Whether LoRA is enabled for this MoE layer.

    Returns:
        The chosen ``Fp8MoeBackend``, or ``None`` on XPU.

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) backend would be
            used with block-quantized weights on an SM100-family device.
    """
    if current_platform.is_xpu():
        return None

    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        if block_quant and current_platform.is_device_capability_family(100):
            raise ValueError(
                "FlashInfer FP8 MoE throughput backend does not "
                "support block quantization. Please use "
                "VLLM_FLASHINFER_MOE_BACKEND=latency "
                "instead."
            )
        logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
        return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only path for older GPUs without native FP8 (never used on ROCm).
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # Determine if we should use DeepGEMM with block-quantized weights:
    # - If explicitly set by user, respect their choice
    # - If not explicitly set (default), disable when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    # Determine if we should use DeepGEMM (top-level enable switch)
    # - If explicitly set by user, respect their choice
    # - If the platform does not support DeepGEMM, disable it
    # This helps avoid warning messages on unsupported platforms.
    use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
    if not is_deep_gemm_supported():
        use_deep_gemm = False
        logger.info_once(
            "DeepGEMM is disabled because the platform does not support it.",
            scope="local",
        )

    if use_deep_gemm and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        else:
            # Platform support was already established above: use_deep_gemm
            # is forced to False whenever is_deep_gemm_supported() is False.
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM

    # AITER kernels only exist on ROCm; do not select them from env flags
    # alone on other platforms.
    if (
        current_platform.is_rocm()
        and envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_MOE
    ):
        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
        return Fp8MoeBackend.AITER

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


212
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. When False, weights are loaded in their original
            dtype and quantized online.
        activation_scheme: Activation quantization scheme; must be one of
            ``ACTIVATION_SCHEMES`` ("static" or "dynamic").
        ignored_layers: Names of layers to leave unquantized (matched via
            ``is_layer_skipped``). ``None`` is treated as an empty list.
        weight_block_size: Optional 2-element block shape for block-wise
            weight quantization. It requires an FP8-serialized checkpoint and
            the "dynamic" activation scheme.

    Raises:
        ValueError: On an unknown activation scheme or an invalid
            ``weight_block_size`` combination.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        The checkpoint is treated as FP8-serialized when ``quant_method``
        contains "fp8". ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is absent or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Return True if the layer at ``prefix`` must stay unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU (IPEX) quantize method for ``layer``, or ``None``."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or ``None`` if unhandled.

        XPU is delegated to ``get_xpu_quant_method``. Linear and MoE layers
        listed in ``ignored_layers`` get their unquantized methods. MoE layers
        use ``Fp8MoEMethod`` for FP8-serialized checkpoints and
        ``Fp8OnlineMoEMethod`` otherwise.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

375

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that totals the element counts of every ``aten.copy_``
    destination seen while the mode is active.

    Weight loaders may ``narrow`` or otherwise view a parameter before copying
    into it. Counting at the ``copy_`` level therefore reports how many
    parameter elements were actually written, whatever views were taken first.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written by aten.copy_.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        result = func(*args, **call_kwargs)
        if func == torch.ops.aten.copy_.default:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


396
397
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale, optionally with block-wise weight scales
    (``weight_block_size``).

    Also supports loading FP16/BF16 model checkpoints with dynamic
    activation scaling. Those weights are quantized to FP8 as soon as all of
    their elements have been loaded, and that is also when the weight scaling
    factor is initialized.

    Limitations:
    1. Without block quantization, the w8a8 (non-Marlin) path uses a single
       per-tensor weight scale. Fused logical shards are requantized together
       due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode has its own GEMM paths in `apply`.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight, weight-scale and input-scale params on ``layer``.

        For FP8-serialized checkpoints, FP8 weight and scale parameters are
        created directly. Otherwise the weight is allocated in
        ``params_dtype``, and its loader is wrapped so that
        ``process_weights_after_loading`` runs once every element has been
        copied in.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded params into the layout the chosen kernel expects.

        This method is a no-op if the patched weight loader has already run
        it. Non-block weights end up transposed (``weight.t()``). FP16/BF16
        checkpoints are quantized here with ``ops.scaled_fp8_quant``.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequant_weight_bf16(
        weight: torch.Tensor, weight_scale: torch.Tensor
    ) -> torch.Tensor:
        """Dequantize a processed (transposed) FP8 weight to BF16.

        Args:
            weight: FP8 weight laid out as ``[K, N]``.
                ``process_weights_after_loading`` stores non-block weights
                transposed, and ``apply`` passes ``weight.t()`` (i.e.
                ``[N, K]``) to ``F.linear``, so N is the output dimension.
            weight_scale: Either a single per-tensor scale, or one scale per
                output channel (``N`` elements, e.g. shape ``[N]`` or
                ``[N, 1]``).

        Returns:
            The ``[K, N]`` BF16 weight with scales applied.
        """
        weight_bf16 = weight.to(torch.bfloat16)
        scale_bf16 = weight_scale.to(torch.bfloat16)
        if scale_bf16.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return weight_bf16 * scale_bf16
        if scale_bf16.numel() == weight_bf16.shape[1]:
            # Per-output-channel: output channels are the columns of the
            # [K, N] weight, so broadcast the scales across the K rows.
            return weight_bf16 * scale_bf16.reshape(1, -1)
        # Fallback: rely on standard broadcasting rules.
        return weight_bf16 * scale_bf16

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatch order:
          1. Batch-invariant mode.
          2. Marlin.
          3. Block-quantized FP8 op.
          4. Per-tensor ``Fp8LinearOp``.
        """
        # In batch-invariant mode, block-quantized weights use the block FP8
        # op; all other weights are dequantized to BF16 and run through a
        # plain GEMM.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # per-tensor/channel: dequant to BF16 and run GEMM
            weight_bf16 = self._dequant_weight_bf16(layer.weight, layer.weight_scale)
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
720
721


722
723
724
725
726
727
728
729
730
731
732
733
734
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale. Weight scales are per-tensor (one per
    expert and weight, stored as ``*_weight_scale``) or block-wise (stored as
    ``*_weight_scale_inv``) when ``quant_config.weight_block_size`` is set.

    Loading FP16/BF16 checkpoints with online quantization is handled by the
    ``Fp8OnlineMoEMethod`` subclass; ``create_weights`` here requires an FP8
    serialized checkpoint.

    Args:
        quant_config: The quantization config.
        layer: The ``FusedMoE`` layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Block-wise scales are named "weight_scale_inv", per-tensor scales
        # "weight_scale"; every scale parameter lookup goes through this name.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            # Reject unsupported configurations at construction time.
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                if layer.renormalize or layer.custom_routing_function is not None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend does not support "
                        "custom routing function or renormalization, but got "
                        f"{layer.renormalize} and {layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, their scales and input scales.

        Registers ``w13_weight`` with shape
        ``[num_experts, 2 * intermediate_size_per_partition, hidden_size]``
        and ``w2_weight`` with shape
        ``[num_experts, hidden_size, intermediate_size_per_partition]``, both
        as ``torch.float8_e4m3fn``. Weight scales are registered as
        ``w13_<weight_scale_name>`` / ``w2_<weight_scale_name>``. For the
        static activation scheme, per-expert ``w13_input_scale`` /
        ``w2_input_scale`` are registered; otherwise they are set to None.

        Raises:
            ValueError: If block quantization is used and the partitioned
                intermediate size is not divisible by the block size.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            # w13 keeps separate w1 and w3 scales until
            # process_weights_after_loading merges them.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _convert_weights_to_kernel_format(
        self,
        layer: Module,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        w13_weight_scale: torch.Tensor,
        w2_weight_scale: torch.Tensor,
    ) -> None:
        """Convert weights/scales to the selected backend's layout.

        Per backend, the weights and scales are transformed as follows:
        - DeepGEMM (block quant only): block post-processing.
        - AITER: weight shuffling.
        - Marlin: repacking; also stores ``layer.workspace``.
        - FlashInfer: w13 is swapped to w31 order, plus TRT-LLM rotation
          for per-tensor.
        - Any other backend: weights and scales are left unchanged.

        The results replace the layer's weight and scale parameters.
        """
        if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            assert self.block_quant
            w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w13_weight,
                ws=w13_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
            w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=w2_weight,
                ws=w2_weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=is_deep_gemm_e8m0_used(),
            )
        elif self.fp8_backend == Fp8MoeBackend.AITER:
            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                w13_weight, w2_weight
            )
        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            (
                workspace,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
            ) = prepare_moe_fp8_layer_for_marlin(
                layer,
                w13_weight,
                w2_weight,
                w13_weight_scale,
                w2_weight_scale,
                input_dtype=self.marlin_input_dtype,
            )
            layer.workspace = workspace
        elif self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_CUTLASS,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            w13_weight = swap_w13_to_w31(w13_weight)
            if self.block_quant:
                w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
            else:
                # TODO(rob): this function is a hack that renames the scaling
                # factors in the Module. This is a hack we should clean up.
                # NOTE(review): this reads scales from `layer`, which are only
                # replaced with the converted values at the end of this method;
                # confirm that is the intended input.
                register_moe_scaling_factors(layer)
                if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)

    def _setup_kernel(self, layer: Module) -> None:
        """Setup Modular Kernel for TP Case"""
        # NOTE(rob): this is a WIP refactor. We are first migrating
        # all of the kernels in the TP case to use mk. Once this is
        # done, then we will initialize the TP case and DP/EP case
        # via the same code path (i.e. via maybe_init_modular_kernel).
        # NOTE(rob): in progress migrating all into this format.

        from vllm.model_executor.layers.fused_moe import (
            TritonOrDeepGemmExperts,
        )
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )
        from vllm.model_executor.layers.fused_moe.prepare_finalize import (
            MoEPrepareAndFinalizeNoEP,
        )
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None
        self.use_inplace = True

        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.kernel = mk.FusedMoEModularKernel(
                # TODO: make defer_input_quant an attr of the FlashInferExperts
                MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant),
                FlashInferExperts(
                    out_dtype=layer.orig_dtype,
                    quant_config=self.moe_quant_config,
                    ep_rank=self.moe.ep_rank,
                    ep_size=self.moe.ep_size,
                    tp_rank=self.moe.tp_rank,
                    tp_size=self.moe.tp_size,
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
            )
            self.use_inplace = False
        elif self.fp8_backend == Fp8MoeBackend.AITER:
            self.kernel = mk.FusedMoEModularKernel(
                # TODO: make defer_input_quant an attr of the AiterExperts
                MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
                AiterExperts(quant_config=self.moe_quant_config),
            )
        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            self.kernel = mk.FusedMoEModularKernel(
                MoEPrepareAndFinalizeNoEP(),
                MarlinExperts(quant_config=self.moe_quant_config),
            )
        else:
            # FLASHINFER_TRTLLM also lands here, but its `apply` path returns
            # before `self.kernel` is used.
            self.kernel = mk.FusedMoEModularKernel(
                MoEPrepareAndFinalizeNoEP(),
                TritonOrDeepGemmExperts(
                    quant_config=self.moe_quant_config,
                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
                ),
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize loaded FP8 weights and build the runtime kernel.

        Steps:
        - Converts to FNUZ on platforms that need it.
        - Collapses static input scales to a single max value.
        - For per-tensor quant, requantizes w1/w3 to a shared per-expert scale.
        - Converts weights to the backend layout and sets up the modular kernel.

        This is a no-op if the layer was already processed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight
        w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13_weight, w13_weight_scale, w13_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(
                    w13_weight, w13_weight_scale, w13_input_scale
                )
            )
            w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2_weight, w2_weight_scale, w2_input_scale
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )
            replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
            replace_parameter(layer, "w2_input_scale", w2_input_scale.max())

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        w13_weight[expert_id][start : start + shard_size, :],
                        w13_weight_scale[expert_id][shard_id],
                    )
                    w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size
            w13_weight_scale = max_w13_scales

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the base-class prepare/finalize, or None for AITER, Marlin
        and FlashInfer TRT-LLM, which opt out of it."""
        if (
            self.fp8_backend == Fp8MoeBackend.AITER
            or self.fp8_backend == Fp8MoeBackend.MARLIN
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation to pair with ``prepare_finalize``.

        Raises:
            NotImplementedError: For the Marlin and ROCm AITER backends.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config.

        Marlin uses the weight-only (w8a16) config, with no activation scales.
        Every other backend uses w8a8 with the layer's input scales.
        """
        if self.fp8_backend == Fp8MoeBackend.MARLIN:
            return fp8_w8a16_moe_quant_config(
                w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
                w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
                block_shape=self.weight_block_size,
            )

        return fp8_w8a8_moe_quant_config(
            w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
            w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass.

        The FlashInfer TRT-LLM backend routes and computes inside its own
        kernels and returns their output directly. All other backends select
        experts via ``layer.select_experts`` and run the modular kernel built
        in ``_setup_kernel``.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                # Return directly: falling through would run the modular
                # kernel below on TRT-LLM-formatted weights and overwrite
                # this result.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1278
1279


1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. Each expert's weights are quantized to FP8 with a
    per-tensor scale as soon as all of the layer's weight elements have
    been loaded.

    Args:
        quant_config: The quantization config.
        layer: The ``FusedMoE`` layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights and FP8 scale placeholders.

        The registered weights use ``params_dtype``; they are quantized later
        by ``process_weights_after_loading``. The ``weight_loader`` in
        ``extra_weight_attrs`` is wrapped so that processing runs
        automatically once every element of ``w13_weight`` and ``w2_weight``
        has been copied in.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loaded
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # Copy into a new dict so that patching "weight_loader" below cannot
        # affect any other object holding a reference to the original attrs.
        extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        extra_weight_attrs["weight_loader"] = patched_weight_loader

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One per-tensor scale per expert for each of w13 and w2. They are
        # not loaded from the checkpoint; process_weights_after_loading fills
        # them in when it quantizes the weights.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weights to FP8 and set up the runtime kernel.

        Each expert's w13 and w2 are quantized with ``ops.scaled_fp8_quant``,
        writing the resulting per-tensor scales into ``w13_weight_scale`` /
        ``w2_weight_scale``. The weights are then converted to the backend
        layout and the modular kernel is built. This is a no-op if the layer
        was already processed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
            )
        # Swap in the FP8 tensors right away so the high-precision originals
        # are no longer referenced by the layer.
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Shuffle weights into the runtime format.
        self._convert_weights_to_kernel_format(
            layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
        )

        # Setup modular kernel for TP case.
        self._setup_kernel(layer)
1414
1415


1416
1417
1418
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1419
1420
1421
    """

    def __init__(self, quant_config: Fp8Config):
1422
        super().__init__(quant_config)