"vscode:/vscode.git/clone" did not exist on "46ecc579733f13a555ca42da76c1234c586271eb"
fp8.py 60.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from enum import Enum
5
from typing import TYPE_CHECKING, Any, Optional
6
7
8
9

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
10
from torch.utils._python_dispatch import TorchDispatchMode
11

12
import vllm.envs as envs
13
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
14
from vllm import _custom_ops as ops
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.attention.layer import Attention
17
from vllm.distributed import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.batch_invariant import (
20
    vllm_is_batch_invariant,
21
)
bnellnm's avatar
bnellnm committed
22
from vllm.model_executor.layers.fused_moe import (
23
24
25
26
27
28
29
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
30
from vllm.model_executor.layers.fused_moe.config import (
31
    FusedMoEParallelConfig,
32
    FusedMoEQuantConfig,
33
    RoutingMethodType,
34
35
    fp8_w8a8_moe_quant_config,
)
36
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
37
38
39
40
41
42
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
from vllm.model_executor.layers.quantization.base_config import (
45
46
47
    QuantizationConfig,
    QuantizeMethodBase,
)
48
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
49
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
50
51
    FlashinferMoeBackend,
    apply_flashinfer_per_tensor_scale_fp8,
52
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
53
54
55
56
57
58
    get_flashinfer_moe_backend,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl,
    swap_w13_to_w31,
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
64
    deepgemm_post_process_fp8_weight_block,
65
66
67
68
69
    maybe_post_process_fp8_weight_block,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    validate_fp8_block_shape,
)
70
71
72
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
73
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
74
75
76
77
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
81
    GroupShape,
    is_layer_skipped,
)
82
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
83
84
85
86
87
88
89
90
91
92
93
94
95
    Fp8LinearOp,
    all_close_1d,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
96
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
97
from vllm.platforms import current_platform
98
from vllm.scalar_type import scalar_types
99
100
101
102
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
103
from vllm.utils.flashinfer import has_flashinfer_moe
104
from vllm.utils.import_utils import has_deep_gemm
105

106
107
108
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

109
110
111
112
# Activation quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = [
    "static",
    "dynamic",
]

logger = init_logger(__name__)

113

114
115
116
117
118
class Fp8MoeBackend(Enum):
    """Kernel backends available for FP8 MoE layers.

    `get_fp8_moe_backend` returns one of these members (or `None` on XPU);
    it never returns `NONE`.
    """

    NONE = 0
    FLASHINFER_TRTLLM = 1
    FLASHINFER_CUTLASS = 2
    DEEPGEMM = 3
    MARLIN = 4
    TRITON = 5
121
122


123
def get_fp8_moe_backend(
    block_quant: bool,
    moe_parallel_config: FusedMoEParallelConfig,
    with_lora_support: bool,
) -> Fp8MoeBackend | None:
    """
    Select the primary FP8 MoE backend.

    Selection order:
      1. XPU -> ``None`` (Fp8Config routes XPU layers to XPUFp8MoEMethod).
      2. LoRA enabled -> Triton.
      3. CUDA SM90 or SM100-family with ``VLLM_USE_FLASHINFER_MOE_FP8`` set
         and FlashInfer MoE available -> FlashInfer TRTLLM or CUTLASS,
         depending on ``get_flashinfer_moe_backend()``.
      4. No native FP8 (below SM89) or ``VLLM_TEST_FORCE_FP8_MARLIN``
         (never on ROCm) -> Marlin weight-only.
      5. Block-quantized weights with DeepGEMM enabled, installed and
         supported -> DeepGEMM.
      6. Otherwise -> Triton.

    Note: Shape-specific fallbacks may still occur at runtime.

    Args:
        block_quant: Whether the weights are block-quantized.
        moe_parallel_config: MoE parallel config; only ``tp_size`` is read.
        with_lora_support: Whether LoRA is enabled for this MoE layer.

    Returns:
        The selected backend, or ``None`` on XPU.

    Raises:
        ValueError: If the FlashInfer CUTLASS (throughput) backend would be
            used with block-quantized weights on an SM100-family GPU.
    """
    if current_platform.is_xpu():
        return None

    if with_lora_support:
        return Fp8MoeBackend.TRITON

    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        else:
            if block_quant and current_platform.is_device_capability_family(100):
                raise ValueError(
                    "FlashInfer FP8 MoE throughput backend does not "
                    "support block quantization. Please use "
                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
                    "instead."
                )
            logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
            return Fp8MoeBackend.FLASHINFER_CUTLASS

    # weight-only path for older GPUs without native FP8
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    if current_platform.is_rocm():
        use_marlin = False
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM is only a candidate for block-quantized weights with
    # VLLM_USE_DEEP_GEMM enabled. Evaluate the TP-size default (and its log
    # message) only in that case, so models that could never use DeepGEMM
    # do not get a misleading "DeepGEMM MoE is disabled" message.
    if envs.VLLM_USE_DEEP_GEMM and block_quant:
        # Determine if we should use DeepGEMM with block-quantized weights:
        # - If explicitly set by user, respect their choice
        # - If not explicitly set (default), disable when TP size is >= 8
        moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
        if (
            not envs.is_set("VLLM_MOE_USE_DEEP_GEMM")
            and moe_parallel_config.tp_size >= 8
        ):
            moe_use_deep_gemm = False
            logger.info_once(
                "DeepGEMM MoE is disabled by default when TP size is >= 8. "
                "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
                scope="local",
            )

        if moe_use_deep_gemm:
            if not has_deep_gemm():
                logger.warning_once(
                    "DeepGEMM backend requested but not available.", scope="local"
                )
            elif is_deep_gemm_supported():
                logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
                return Fp8MoeBackend.DEEPGEMM

    # default to Triton
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


198
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. If False, weights are quantized after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Layer prefixes that stay unquantized.
        weight_block_size: Optional 2-element block shape for block-wise
            weight quantization. Requires an FP8-serialized checkpoint and
            the "dynamic" activation scheme.

    Raises:
        ValueError: On an unsupported activation scheme or an invalid
            block-quantization setup.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # GPUs below SM89 lack native FP8; Fp8LinearMethod uses the Marlin
        # weight-only path for them.
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``ignored_layers`` from HF names to vLLM module names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict.

        ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """Return True if the layer at ``prefix`` must stay unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer`` on Intel XPU."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh Fp8Config carrying the same settings.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)

            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for ``layer``, or None if unhandled."""
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            # Serialized FP8 checkpoints load FP8 weights directly; otherwise
            # the experts are quantized online.
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None if
            ``name`` is not a cache-scale parameter
        """
        # Match the exact "<proj>.output_scale" suffix. A name that merely
        # contains ".k_proj" elsewhere is not a KV-cache scale, and must not
        # be returned unchanged or mapped to the wrong projection.
        proj_suffixes = (
            (".k_proj.output_scale", ".attn.k_scale"),
            (".v_proj.output_scale", ".attn.v_scale"),
            (".q_proj.output_scale", ".attn.q_scale"),
        )
        for suffix, replacement in proj_suffixes:
            if name.endswith(suffix):
                return name.removesuffix(suffix) + replacement
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

361

362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that adds up the element count of every `copy_` target.

    Weight loaders may slice or reshape a parameter (e.g. with `narrow`)
    before copying into it. Summing the sizes of the tensors written by
    `aten.copy_` therefore tracks how much of a weight has really been filled.
    """

    def __init__(self):
        super().__init__()
        # Running total of elements written by aten.copy_ while this mode is active.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs or {}))
        is_inplace_copy = func == torch.ops.aten.copy_.default
        if is_inplace_copy:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


382
383
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling: the weights are quantized to FP8 and the weight
    scaling factor is computed after the model weights are loaded.

    Limitations:
    1. Non-block weights are requantized to a single per-tensor scale when
       running w8a8 (torch._scaled_mm needs per-tensor scales). Block-wise
       weights (``weight_block_size``) require dynamic activation scaling.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        For FP8-serialized checkpoints this registers ``weight``, then
        ``weight_scale`` (per-tensor) or ``weight_scale_inv`` (block), and
        ``input_scale`` (a parameter for static activations, else None).

        For unquantized checkpoints, only a ``weight`` in ``params_dtype`` is
        registered. Its loader triggers ``process_weights_after_loading`` as
        soon as every element has been copied in.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the selected kernel needs.

        Non-block weights are stored transposed (``[K, N]``) afterwards. This
        is a no-op if the online-quantization loader already ran it.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        else:
            # If checkpoint not serialized fp8, quantize the weights.
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_weight_to_bf16(
        weight: torch.Tensor, weight_scale: torch.Tensor
    ) -> torch.Tensor:
        """Dequantize a processed (transposed, ``[K, N]``) FP8 weight to BF16.

        Scale layouts handled:
          * one element (per-tensor): broadcast over the whole weight;
          * ``N`` elements (one per output channel, e.g. ``[N]`` or
            ``[N, 1]``): applied to the ``N`` output columns;
          * anything else: left to ordinary broadcasting rules.
        """
        weight_bf16 = weight.to(torch.bfloat16)
        scale_bf16 = weight_scale.to(torch.bfloat16)
        if scale_bf16.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return weight_bf16 * scale_bf16
        if scale_bf16.numel() == weight_bf16.shape[1]:
            # Per-output-channel: the weight is [K, N], so the scale must
            # broadcast along the columns (dim 1), not the rows.
            return weight_bf16 * scale_bf16.reshape(1, -1)
        # Fallback
        return weight_bf16 * scale_bf16

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the FP8 linear layer output for input ``x``.

        Paths, in order: batch-invariant mode (block op, or BF16 dequant +
        ``F.linear``), Marlin weight-only, block-wise w8a8, per-tensor w8a8.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM
                weight_bf16 = self._dequantize_weight_to_bf16(
                    layer.weight, layer.weight_scale
                )
                # F.linear expects [N, K]; the stored weight is [K, N].
                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
706
707


708
709
710
711
712
713
714
715
716
717
718
719
720
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with static weight scales (per-tensor,
    or block-wise when ``weight_block_size`` is set) and dynamic or static
    activation scales. ``create_weights`` asserts that the checkpoint is
    FP8-serialized, so unquantized checkpoints are not handled by this class
    directly.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to. Its ``moe_config``
            is passed to the base class, and its ``moe_parallel_config`` is
            used to select the FP8 MoE backend.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.fp8_backend = get_fp8_moe_backend(
            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
        )

        self.marlin_input_dtype = None
        self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
        self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if not self.block_quant:
                # Per-tensor FlashInfer FP8 MoE needs a custom routing
                # function and no renormalization. This is the same condition
                # `apply` asserts for the per-tensor TensorRT-LLM path.
                if layer.renormalize or layer.custom_routing_function is None:
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend requires a custom "
                        "routing function without renormalization, but got "
                        f"renormalize={layer.renormalize} and "
                        f"custom_routing_function={layer.custom_routing_function}."
                    )
                if layer.scoring_func != "sigmoid":
                    raise NotImplementedError(
                        "FlashInfer CUTLASS FP8 MoE backend only supports "
                        f"'sigmoid' scoring function, but got {layer.scoring_func}."
                    )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights and their scales on ``layer``.

        Registers ``w13_weight`` with shape
        (num_experts, 2 * intermediate_size_per_partition, hidden_size) and
        ``w2_weight`` with shape
        (num_experts, hidden_size, intermediate_size_per_partition), both as
        ``torch.float8_e4m3fn``.

        Weight scales depend on the quantization mode:

        * Per-tensor: ``w13_weight_scale`` has shape (num_experts, 2) and
          ``w2_weight_scale`` has shape (num_experts,).
        * Block-wise: ``w13_weight_scale_inv`` and ``w2_weight_scale_inv``
          are ceil-divided block grids.

        Static activation scales (``w13_input_scale`` / ``w2_input_scale``)
        are registered only when ``activation_scheme == "static"``; otherwise
        they are set to None.

        Raises:
            ValueError: if ``intermediate_size_per_partition`` is not
                divisible by ``block_n``, or (with TP > 1) by ``block_k``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded FP8 weights/scales for the selected backend.

        The method does the following, depending on configuration:

        * Normalizes e4m3fn to e4m3fnuz on fnuz platforms.
        * Swaps w1/w3 halves for FlashInfer.
        * Shuffles weights for ROCm AITER.
        * Re-lays-out block scales for DeepGEMM.
        * In the per-tensor case, collapses the two w13 shard scales into one
          scale per expert by re-quantizing each shard.

        Finally, it prepares Marlin weights and builds ``self.kernel`` for the
        FlashInfer CUTLASS, DeepGEMM and Triton backends.

        Returns immediately if the layer has
        ``_already_called_process_weights_after_loading`` set.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # Input scales are None under the dynamic scheme asserted above.
                w13_weight, w13_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight,
                    layer.w13_weight_scale_inv,
                    layer.w13_input_scale,
                )
                w2_weight, w2_weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
                )
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            replace_parameter(layer, "w13_weight", w13_weight)
            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
            replace_parameter(layer, "w2_weight", w2_weight)
            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
                dg_w13_weight, dg_w13_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w13_weight.data,
                        ws=layer.w13_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                dg_w2_weight, dg_w2_weight_scale_inv = (
                    deepgemm_post_process_fp8_weight_block(
                        wq=layer.w2_weight.data,
                        ws=layer.w2_weight_scale_inv.data,
                        quant_block_shape=tuple(layer.weight_block_size),
                        use_e8m0=is_deep_gemm_e8m0_used(),
                    )
                )
                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
                layer.w13_weight_scale_inv = Parameter(
                    dg_w13_weight_scale_inv, requires_grad=False
                )
                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
                layer.w2_weight_scale_inv = Parameter(
                    dg_w2_weight_scale_inv, requires_grad=False
                )
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer."
                    )
                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                replace_parameter(layer, "w13_weight", w13_weight)
                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
                if w13_input_scale is not None:
                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
                replace_parameter(layer, "w2_weight", w2_weight)
                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
                if w2_input_scale is not None:
                    replace_parameter(layer, "w2_input_scale", w2_input_scale)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                    layer.w13_weight, layer.w2_weight
                )

                replace_parameter(layer, "w13_weight", shuffled_w13)
                replace_parameter(layer, "w2_weight", shuffled_w2)

            replace_parameter(layer, "w13_weight_scale", max_w13_scales)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                    # Pass the layer's w2 parameter itself: a local `w2_weight`
                    # only exists on the fnuz path above. Passing the Parameter
                    # also means any in-place update of its data reaches the
                    # layer.
                    rotate_flashinfer_fp8_moe_weights(w13_weight, layer.w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        # NOTE(rob): this is a WIP refactor. We are first migrating
        # all of the kernels in the TP case to use mk. Once this is
        # done, then we will initialize the TP case and DP/EP case
        # via the same code path (i.e. via maybe_init_modular_kernel).
        # NOTE(rob): in progress migrating all into this format.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
                FlashInferExperts,
            )
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
                FlashInferAllGatherMoEPrepareAndFinalize,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config

            self.kernel = mk.FusedMoEModularKernel(
                FlashInferAllGatherMoEPrepareAndFinalize(
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
                FlashInferExperts(
                    out_dtype=torch.get_default_dtype(),
                    quant_config=self.moe_quant_config,
                    ep_rank=self.moe.ep_rank,
                    ep_size=self.moe.ep_size,
                    tp_rank=self.moe.tp_rank,
                    tp_size=self.moe.tp_size,
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
                ),
            )
            self.use_inplace = False

        elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
            from vllm.model_executor.layers.fused_moe import (
                TritonOrDeepGemmExperts,
            )
            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                MoEPrepareAndFinalizeNoEP,
            )

            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config
            self.kernel = mk.FusedMoEModularKernel(
                MoEPrepareAndFinalizeNoEP(),
                TritonOrDeepGemmExperts(
                    quant_config=self.moe_quant_config,
                    allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
                ),
            )
            self.use_inplace = True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize stage for all2all, or None if unused.

        Returns None for XPU, ROCm AITER, Marlin and FlashInfer TensorRT-LLM,
        which do not go through this path. FlashInfer CUTLASS gets its own
        prepare/finalize; everything else defers to the base class.
        """
        if (
            current_platform.is_xpu()
            or self.rocm_aiter_moe_enabled
            or self.use_marlin
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            if self.block_quant:
                assert self.weight_block_size == [128, 128], (
                    f"Only support weight_block_size == [128, 128], "
                    f"got {self.weight_block_size}"
                )
            # Wire block-scale flag through prepare/finalize when using CUTLASS
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the experts implementation that pairs with ``prepare_finalize``.

        Selection order:

        * Batched DeepGEMM/Triton experts for the batched activation format.
        * Triton experts when LoRA is enabled.
        * The FlashInfer CUTLASS experts for the CUTLASS backend.
        * Otherwise, ``TritonOrDeepGemmExperts``.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet."
        )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )

        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
            )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 MoE quant config from the layer's scales.

        Returns None for Marlin. Block-quantized layers use the ``*_scale_inv``
        tensors; per-tensor layers use ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(
                layer.w13_weight_scale_inv
                if self.block_quant
                else layer.w13_weight_scale
            ),
            w2_scale=(
                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
            ),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass on ``x`` with the configured backend.

        The FlashInfer TensorRT-LLM path consumes ``router_logits`` directly
        and returns early. All other backends first compute top-k weights/ids
        via ``layer.select_experts`` and then dispatch to:

        * ROCm AITER,
        * Marlin, or
        * the modular kernel built in ``process_weights_after_loading``.
        """
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                assert (
                    not layer.renormalize and layer.custom_routing_function is not None
                )
                # Return directly. This call takes the raw router_logits, and
                # `self.kernel` (used by the fallthrough path below) is never
                # built for the TensorRT-LLM backend.
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts,
            )

            # TODO(rob): convert this to MK.
            result = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )
        elif self.use_marlin:
            # TODO(rob): convert this to MK.
            assert layer.activation == "silu", (
                f"{layer.activation} not supported for Marlin MoE."
            )
            result = fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
                workspace=layer.workspace,
            )
        else:
            result = self.kernel(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=self.use_inplace,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        return result
1355
1356


1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads unquantized, high-precision (e.g. FP16/BF16) expert weights from a
    checkpoint and quantizes them to FP8 after loading, producing one weight
    scale per expert for ``w13`` and one per expert for ``w2``. Activations
    use dynamic scaling. Quantization is triggered from a patched weight
    loader as soon as the last weight chunk of the layer has been copied in,
    i.e. in a streaming fashion rather than only at the end of model loading.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Online quantization is only supported for non-FP8 checkpoints with
        # dynamic activation scales and without block-wise weight scales.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None
        assert self.flashinfer_moe_backend is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights and per-expert FP8 scales.

        The ``weight_loader`` found in ``extra_weight_attrs`` is wrapped so
        that, once every element of ``w13_weight`` and ``w2_weight`` has been
        copied in, ``process_weights_after_loading`` runs immediately and the
        later regular call becomes a no-op for this layer.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization: patch the weight loader to call
        # `process_weights_after_loading` in a streaming fashion as soon as
        # the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Counter (stored on the layer) of how many weight elements have
            # been written so far across all loader calls for this layer.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Load the current weight chunk while counting copied elements.
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # If we have loaded all of the elements, quantize right away.
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping.
                del layer._loaded_numel
                # Make the regular `process_weights_after_loading` call a
                # no-op for this layer, since the work is already done.
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a new dict rather than mutating the incoming one, so no other
        # holder of the original mapping observes the patched loader.
        extra_weight_attrs = dict(extra_weight_attrs)
        extra_weight_attrs["weight_loader"] = patched_weight_loader

        # WEIGHTS (kept in the checkpoint dtype until quantized)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # A single scale per expert for w13 (the combined w1/w3 slice is
        # quantized as one tensor) and a single scale per expert for w2.
        # Initialized to 1.0 and overwritten in
        # `process_weights_after_loading`.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Dynamic activation scaling: there are no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        # Re-evaluated in `process_weights_after_loading`.
        self.rocm_aiter_moe_enabled = False

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision expert weights to FP8.

        Returns early if the streaming weight loader already processed this
        layer. Otherwise each expert's ``w13`` and ``w2`` slice is quantized
        with ``ops.scaled_fp8_quant``, its scale stored in the per-expert
        scale parameter, and the weights are then optionally reshuffled for
        ROCm AITER and/or repacked for Marlin.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Decide at weight-processing time whether the AITER fused MoE path
        # is used; this overrides the placeholder set in `create_weights`.
        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

        # The checkpoint is high precision: quantize each expert to FP8.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w2_weight", w2_weight)

        # Reshuffle weights for AITER if needed.
        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)

        # Reshuffle weights for Marlin if needed.
        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )


1504
1505
1506
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1507
1508
1509
    """

    def __init__(self, quant_config: Fp8Config):
1510
        super().__init__(quant_config)