fp8.py 60.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
10
from torch.utils._python_dispatch import TorchDispatchMode
11

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.attention.layer import Attention
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.batch_invariant import (
20
    vllm_is_batch_invariant,
21
)
bnellnm's avatar
bnellnm committed
22
from vllm.model_executor.layers.fused_moe import (
23
24
25
26
27
28
29
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
30
from vllm.model_executor.layers.fused_moe.config import (
31
    FusedMoEParallelConfig,
32
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
35
    fp8_w8a8_moe_quant_config,
)
36
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
37
38
39
40
41
42
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
from vllm.model_executor.layers.quantization.base_config import (
45
46
47
    QuantizationConfig,
    QuantizeMethodBase,
)
48
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
49
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
50
51
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
52
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
53
54
55
56
57
58
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
64
    deepgemm_post_process_fp8_weight_block,
65
66
67
68
69
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
70
71
72
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
73
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
74
75
76
77
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
81
    GroupShape,
    is_layer_skipped,
)
82
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
83
84
85
86
87
88
89
90
91
92
93
94
95
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
96
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
97
from vllm.platforms import current_platform
98
from vllm.scalar_type import scalar_types
99
100
101
102
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
103
from vllm.utils.flashinfer import has_flashinfer_moe
104
from vllm.utils.import_utils import has_deep_gemm
105

106
107
108
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

109
110
111
112
logger = init_logger(__name__)

# Activation quantization schemes accepted by ``Fp8Config``; any other value
# is rejected in its constructor.
ACTIVATION_SCHEMES = ["static", "dynamic"]

113

114
115
116
117
118
class Fp8MoeBackend(Enum):
    """Kernel backends available for FP8 fused MoE.

    ``get_fp8_moe_backend`` returns every member except ``NONE``.
    """

    NONE = 0
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    MARLIN = 4
    TRITON = 5
121
122


123
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend:
    """
    Select the primary FP8 MoE backend.

    Selection order:
    1. LoRA enabled -> Triton.
    2. CUDA device in the SM100 family or with capability 90, with
       ``VLLM_USE_FLASHINFER_MOE_FP8`` set and FlashInfer MoE available ->
       FlashInfer TRTLLM or CUTLASS, depending on
       ``get_flashinfer_moe_backend()``.
    3. Device lacking capability 89 (or ``VLLM_TEST_FORCE_FP8_MARLIN``), on a
       non-ROCm platform -> Marlin (weight-only).
    4. Block-quantized weights with DeepGEMM requested, installed and
       supported -> DeepGEMM.
    5. Otherwise -> Triton.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the MoE weights are block-quantized.
        moe_parallel_config: Parallel config; its ``tp_size`` changes the
            default for ``VLLM_MOE_USE_DEEP_GEMM``.
        with_lora_support: Whether LoRA is enabled for this MoE layer.

    Returns:
        The chosen ``Fp8MoeBackend`` (never ``Fp8MoeBackend.NONE``).

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) backend would be
            used for block-quantized weights on an SM100-family device.
    """
    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        else:
            if block_quant and current_platform.is_device_capability_family(100):
                raise ValueError(
                    "FlashInfer FP8 MoE throughput backend does not "
                    "support block quantization. Please use "
                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
                    "instead."
                )
            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
            return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    # Marlin is never used on ROCm.
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # Determine if we should use DeepGEMM with block-quantized weights:
    # - If explicitly set by user, respect their choice
    # - If not explicitly set (default), disable when TP size is >= 8
    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
        moe_use_deep_gemm = False
        logger.info_once(
            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
            scope="local",
        )

    if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
        if not has_deep_gemm():
            logger.warning_once(
                "DeepGEMM backend requested but not available.", scope="local"
            )
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
            return Fp8MoeBackend.DEEPGEMM
        # Installed but unsupported on this device: fall through to Triton.

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


196
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. If False, weights are loaded in high precision and
            quantized online.
        activation_scheme: Activation quantization scheme; must be one of
            ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer names matched via ``is_layer_skipped`` that are
            left unquantized. ``None`` is treated as an empty list.
        weight_block_size: Optional 2-element block shape for block-wise weight
            quantization. Requires an FP8-serialized checkpoint and the
            "dynamic" activation scheme.

    Raises:
        ValueError: On an unsupported activation scheme, or an invalid
            ``weight_block_size`` combination.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the quantization method name, ``"fp8"``."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this config supports (bf16, fp16)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this config."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """FP8 needs no extra config files beyond the model config."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization-config dict.

        ``quant_method`` and ``activation_scheme`` are required keys. The
        checkpoint counts as FP8-serialized when ``quant_method`` contains
        ``"fp8"``. Ignored layers come from ``ignored_layers``, falling back to
        ``modules_to_not_convert`` when that is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _should_skip(self, prefix: str) -> bool:
        """Return True if the layer at ``prefix`` must stay unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the XPU-specific quantize method for ``layer``, or None."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # NOTE(review): this copy only carries the constructor arguments; any
        # state set on ``self`` after construction is not forwarded. Confirm
        # the XPU methods don't rely on it.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._should_skip(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unsupported.

        XPU is delegated to ``get_xpu_quant_method``. Linear and MoE layers in
        ``ignored_layers`` get their unquantized counterparts. MoE layers use
        ``Fp8MoEMethod`` for FP8-serialized checkpoints and
        ``Fp8OnlineMoEMethod`` otherwise.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._should_skip(prefix):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if self._should_skip(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

352

353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that sums the ``numel()`` of every destination tensor written
    by ``aten.copy_`` while the mode is active.

    Counting at the dispatcher level means the total reflects what was
    actually written, even when a weight loader first reshapes or slices the
    destination (for example with ``narrow``).
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written via ``copy_``.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs or {}))
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            # args[0] is the in-place destination of ``copy_``.
            self.copied_numel += args[0].numel()
        return result


373
374
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 checkpoints: the weights are
    quantized to FP8 online, and the weight scaling factor is computed once
    all weight elements have been loaded.

    Limitations:
    1. Outside of block quantization and the Marlin path, the weight is
       requantized to a single per-tensor scale, because torch._scaled_mm
       needs per-tensor scales.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode has its own branch in apply() and never uses
        # Marlin.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        # (Helper name spelled exactly as exported by rocm_aiter_ops.)
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        # NOTE(review): not read inside this class; presumably consumed
        # elsewhere.
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # Activations are quantized in groups of 1 x block_k.
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        For FP8-serialized checkpoints this registers:
        - an FP8 ``weight``;
        - either ``weight_scale`` (per-tensor) or ``weight_scale_inv`` (block);
        - ``input_scale``, which is ``None`` unless activations are static.

        For unquantized checkpoints only a ``params_dtype`` ``weight`` is
        registered. Its loader calls ``process_weights_after_loading`` as soon
        as every element of the weight has been copied in.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the selected kernel expects.

        Block-quantized weights go through ``process_fp8_weight_block_strategy``.
        Other weights are handled as follows:
        - Unquantized checkpoints are quantized online.
        - FP8 checkpoints on the non-Marlin path are requantized across their
          logical shards.
        In both cases the weight is then stored transposed. Marlin layers are
        additionally repacked and their ``input_scale`` is deleted.

        Returns immediately if the eager weight-loader hook installed by
        ``create_weights`` already ran this method for ``layer``.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_transposed_weight(
        weight: torch.Tensor, weight_scale: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """Dequantize a processed non-block FP8 weight to ``dtype``.

        ``process_weights_after_loading`` stores non-block weights transposed,
        as ``[K, N]``, so output channels run along the last dimension.
        - A single-element scale is applied uniformly.
        - A scale with exactly ``N`` elements (e.g. ``[N]`` or ``[N, 1]``) is
          treated as per-output-channel and broadcast along the last dimension.
        - Anything else falls back to plain broadcasting.
        """
        weight_hp = weight.to(dtype)
        scale = weight_scale.to(dtype)
        if scale.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return weight_hp * scale
        num_out = weight_hp.shape[-1]
        if scale.numel() == num_out:
            # Per-output-channel: the scale must index the N (last) axis, not
            # the K (first) axis of the transposed weight.
            return weight_hp * scale.reshape(1, num_out)
        # Fallback: rely on broadcasting
        return weight_hp * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (with optional ``bias``)."""
        # Batch-invariant mode: block-quantized weights go through the block
        # FP8 op; per-tensor/channel weights are dequantized to the activation
        # dtype and run through a plain GEMM.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # Dequantize to x.dtype: F.linear requires matching dtypes, so a
            # fixed bf16 weight would fail for fp16 activations.
            weight_hp = self._dequantize_transposed_weight(
                layer.weight, layer.weight_scale, x.dtype
            )
            return torch.nn.functional.linear(x, weight_hp.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
697
698


699
700
701
702
703
704
705
706
707
708
709
710
711
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with a static weight scale (per-tensor,
    or block-wise when ``weight_block_size`` is set in the config) and a
    dynamic or static activation scale.

    ``create_weights`` asserts that the checkpoint is FP8-serialized; online
    quantization of FP16/BF16 checkpoints is handled by the
    ``Fp8OnlineMoEMethod`` subclass.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            and ``moe_parallel_config`` drive backend selection.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            # Reject configurations the FlashInfer CUTLASS path cannot serve
            # up front, rather than failing later inside the kernel.
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                if layer.renormalize or layer.custom_routing_function is not None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend does not support "
                        "custom routing function or renormalization, but got "
                        f"{layer.renormalize} and {layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales on *layer*.

        Registers ``w13_weight`` (fused gate/up projection) and ``w2_weight``
        (down projection) as ``float8_e4m3fn`` parameters. It also registers:

        * per-tensor quantization: ``w13_weight_scale`` of shape
          ``(num_experts, 2)`` and ``w2_weight_scale`` of shape
          ``(num_experts,)``;
        * block quantization: ``w13_weight_scale_inv`` / ``w2_weight_scale_inv``
          with one entry per ``[block_n, block_k]`` block;
        * static activation scheme only: ``w13_input_scale`` /
          ``w2_input_scale`` of shape ``(num_experts,)``. Otherwise both are
          set to ``None``.

        Raises:
            ValueError: if ``intermediate_size_per_partition`` is not divisible
                by ``block_n`` (or by ``block_k`` when TP > 1) under block
                quantization.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per [block_n, block_k] tile (ceil-divided).
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded checkpoint weights into the layout the backend needs.

        Depending on the backend and platform this:

        * normalizes e4m3fn weights/scales to e4m3fnuz;
        * swaps the w1/w3 halves for FlashInfer;
        * shuffles weights for ROCm AITER;
        * re-lays out block scales for DeepGEMM;
        * collapses the two per-shard w13 scales into one per-expert scale
          (per-tensor path);
        * prepares Marlin;
        * builds the modular kernel for the FlashInfer CUTLASS / Triton /
          DeepGEMM backends.

        Returns early if the layer carries the
        ``_already_called_process_weights_after_loading`` flag.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight,
                        layer.w13_weight_scale_inv,
                        layer.w13_input_scale,
                    )
                )
                w2_weight, w2_weight_scale_inv, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                    )
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                # Pass plain tensors (not Parameters) like the other branches.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            replace_parameter(layer, "w13_weight", w13_weight)
            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
            replace_parameter(layer, "w2_weight", w2_weight)
            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )

        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                replace_parameter(layer, "w13_weight", w13_weight)
                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
                if w13_input_scale is not None:
                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
                replace_parameter(layer, "w2_weight", w2_weight)
                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
                if w2_input_scale is not None:
                    replace_parameter(layer, "w2_input_scale", w2_input_scale)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # Shard 0 is w1 (gate), shard 1 is w3 (up); each was loaded
                # with its own scale and is requantized to the shared max.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            replace_parameter(layer, "w13_weight_scale", max_w13_scales)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Use the layer's current w2 weight: the local `w2_weight`
                    # above is only bound on fnuz platforms and may predate
                    # the AITER shuffle.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        # NOTE(rob): this is a WIP refactor. We are first migrating
        # all of the kernels in the TP case to use mk. Once this is
        # done, then we will initialize the TP case and DP/EP case
        # via the same code path (i.e. via maybe_init_modular_kernel).
        # NOTE(rob): in progress migrating all into this format.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
                FlashInferAllGatherMoEPrepareAndFinalize,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config

            self.kernel = mk.FusedMoEModularKernel(
                FlashInferAllGatherMoEPrepareAndFinalize(
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
                FlashInferExperts(
                    out_dtype=torch.get_default_dtype(),
                    quant_config=self.moe_quant_config,
                    ep_rank=self.moe.ep_rank,
                    ep_size=self.moe.ep_size,
                    tp_rank=self.moe.tp_rank,
                    tp_size=self.moe.tp_size,
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
            )
            self.use_inplace = False

        elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
            from vllm.model_executor.layers.fused_moe import (
                TritonOrDeepGemmExperts,
            )
            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                MoEPrepareAndFinalizeNoEP,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config
            self.kernel = mk.FusedMoEModularKernel(
                MoEPrepareAndFinalizeNoEP(),
                TritonOrDeepGemmExperts(
                    quant_config=self.moe_quant_config,
                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
                ),
            )
            self.use_inplace = True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage for all2all, or ``None``.

        ``None`` is returned for ROCm AITER, Marlin and FlashInfer TRTLLM.
        FlashInfer CUTLASS builds its own stage; everything else defers to
        the base class.
        """
        if (
            self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation that pairs with *prepare_finalize*.

        The choice is made in this order:

        * batched DeepGEMM/Triton experts for batched activation formats;
        * Triton experts when LoRA is enabled;
        * FlashInfer CUTLASS experts for the CUTLASS backend;
        * ``TritonOrDeepGemmExperts`` otherwise.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )

        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config from the layer's scales.

        Returns ``None`` for Marlin. Block-quantized layers use the
        ``*_weight_scale_inv`` parameters; per-tensor layers use
        ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 fused-MoE forward pass for hidden states *x*.

        The FlashInfer TRTLLM backend does its own routing from
        *router_logits* and returns directly. All other backends first call
        ``layer.select_experts`` and then dispatch to ROCm AITER, Marlin, or
        the modular kernel.

        Returns:
            The MoE output tensor. When the layer has zero experts configured
            (``zero_expert_num != 0`` and ``zero_expert_type`` set), returns
            a ``(result, zero_expert_result)`` tuple instead.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                # Imported for its side effect; presumably registers the
                # torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8 op used
                # below — TODO confirm.
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                result = apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )
                # The TRTLLM kernel already did routing + experts; falling
                # through would discard this result and call `self.kernel`,
                # which is never built for the TRTLLM backend.
                return result

        select_result = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        topk_weights, topk_ids, zero_expert_result = select_result

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            # TODO(rob): convert this to MK.
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            # TODO(rob): convert this to MK.
            assert layer.activation == "silu", (
                f"{layer.activation} not supported for Marlin MoE."
            )
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
                workspace=layer.workspace,
            )
        else:
            result = self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=self.use_inplace,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
            assert not isinstance(result, tuple), (
                "Shared + zero experts are mutually exclusive not yet supported"
            )
            return result, zero_expert_result
        else:
            return result
1353
1354


1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weights are quantized to FP8 (one scale per
    expert for w13 and one per expert for w2) after they have been loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Online quantization only supports this configuration: an
        # unquantized checkpoint, dynamic activation scales, per-tensor
        # (not block-wise) weight scales and no FlashInfer MoE backend.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate high-precision MoE weights and per-expert FP8 scales.

        The loader in ``extra_weight_attrs["weight_loader"]`` is wrapped so
        that ``process_weights_after_loading`` runs as soon as every element
        of ``w13_weight`` and ``w2_weight`` has been copied in. The layer is
        therefore quantized in a streaming fashion while the model loads.

        Args:
            layer: The MoE layer to register parameters on.
            num_experts: Number of experts to allocate weights for.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Dtype of the (unquantized) checkpoint weights.
            **extra_weight_attrs: Attributes attached to the weight
                parameters; must contain ``"weight_loader"``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization: patch the weight loader so that
        # `process_weights_after_loading` is called in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Counter tracking how many elements have been copied into this
            # layer's weights so far.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Load the current weight chunk while counting copied elements.
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # If we have loaded all of the elements, call
            # process_weights_after_loading.
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping.
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything.
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a new attrs dict (rather than aliasing the old one) so that
        # installing the patched loader cannot modify any other object that
        # might hold a reference to the original dict.
        extra_weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS (kept in the checkpoint dtype until they are quantized).
        # w13 stacks the w1 and w3 projections, hence 2 * intermediate size.
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One per-tensor scale per expert for w13 and one for w2. They are
        # computed in `process_weights_after_loading` rather than loaded from
        # the checkpoint, so no weight loader is attached to them.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Activations are scaled dynamically, so there are no input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision MoE weights to FP8.

        Each expert's w13 and w2 weights are quantized with
        ``ops.scaled_fp8_quant``. The resulting scales are written into
        ``w13_weight_scale`` / ``w2_weight_scale``. Afterwards the weights
        are reshuffled for ROCm AITER or Marlin when those kernels are used.
        This is a no-op if the streaming loader installed by
        ``create_weights`` has already processed the layer.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Resolved here rather than in `create_weights` (which sets it to
        # False); presumably deferred so AITER/triton are not touched too
        # early.
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # Quantize into freshly allocated FP8 tensors, then swap them in.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for MARLIN if needed.
        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )


1502
1503
1504
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1505
1506
1507
    """

    def __init__(self, quant_config: Fp8Config):
1508
        super().__init__(quant_config)