fp8.py 42.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
6
7

import torch
8
import torch.nn.functional as F
9
10
11
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
16
17
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
18
19
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
20
from vllm.model_executor.layers.quantization import QuantizationMethods
21
from vllm.model_executor.layers.quantization.base_config import (
22
    QuantizationConfig, QuantizeMethodBase)
23
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
24
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
25
26
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
27
28
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
29
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
30
31
32
33
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
34
35
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
36
                                           PerTensorScaleParameter)
37
from vllm.model_executor.utils import set_weight_attrs
38
from vllm.platforms import current_platform
39
from vllm.scalar_type import scalar_types
40
from vllm.utils import has_deep_gemm
41

42
43
44
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

45
46
47
48
logger = init_logger(__name__)

# Activation scaling schemes accepted by Fp8Config. With "static", the
# linear and MoE methods register input scales to be loaded from the
# checkpoint; with "dynamic", no input scale is registered for loading.
ACTIVATION_SCHEMES = ["static", "dynamic"]


def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

55

56
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds the settings parsed from a checkpoint's quantization config:
    whether the checkpoint already stores FP8 weights, the activation
    scaling scheme, layer prefixes to leave unquantized, and an optional
    2-D ``[block_n, block_k]`` size for block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: True when the checkpoint weights
                are already stored in FP8.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer prefixes that stay unquantized.
            weight_block_size: ``[block_n, block_k]`` for block-wise weight
                quantization, or None for per-tensor weight scales.

        Raises:
            ValueError: If the activation scheme is unknown, or the block
                size is set for a non-FP8 checkpoint, does not have exactly
                two entries, has a non-positive entry, or is combined with
                a non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            # Block sizes are used as divisors (ceil-division of the scale
            # shapes and divisibility checks in create_weights); reject
            # zero/negative sizes here instead of failing much later.
            if any(size <= 0 for size in weight_block_size):
                raise ValueError(
                    "The quantization block size of weight must be positive, "
                    f"but got {weight_block_size}.")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (presumably SM 8.0)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are required."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper") -> None:
        """Translate ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(
                self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build a config from a checkpoint's quantization-config dict.

        The checkpoint counts as FP8-serialized when its ``quant_method``
        string contains ``"fp8"``. ``activation_scheme`` is required;
        ``ignored_layers`` and ``weight_block_size`` are optional.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers listed in ``ignored_layers`` stay unquantized; other
        linear layers use FP8. Fused MoE layers and attention layers get
        their FP8 methods. Any other layer returns None.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Loads FP8 checkpoints that carry static weight scales together with
    either dynamic or static activation scales. When ``weight_block_size``
    is configured, weights use block-wise scales (dynamic activations only).

    Also loads unquantized FP16/BF16 checkpoints. Their weights are
    quantized to FP8, and the weight scale is computed, after loading.
    Activations then use dynamic scaling.

    Kernel selection:
    - GPUs without FP8 device capability (89), or runs with
      ``VLLM_TEST_FORCE_FP8_MARLIN`` set, use the weight-only Marlin kernel.
      Activations are not quantized on this path.
    - Marlin is never used on ROCm.

    Limitations:
    1. Outside the block-wise path, each weight carries a single per-tensor
       scale. Per-shard scales of fused modules are requantized to their
       max, because torch._scaled_mm is used for per-tensor quantization.
    2. Only the float8_e4m3fn data type is supported, due to the limitation
       of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # Marlin gives fast weight-only FP8 on GPUs lacking FP8 hardware
        # support (or when forced for testing); it is never used on ROCm.
        marlin_wanted = (not current_platform.has_device_capability(89)
                         or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        self.use_marlin = False if current_platform.is_rocm() else marlin_wanted

        # AITER: ROCm only, FP8 FNUZ only (currently the MI300 series), and
        # only when both AITER env switches are enabled.
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None

        # Dynamic activation quantization goes per-token whenever CUTLASS
        # FP8 is available.
        per_token = cutlass_fp8_supported()
        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=per_token)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the FP8 linear parameters on ``layer``.

        ``weight`` is always registered with shape
        ``(sum(output_partition_sizes), input_size_per_partition)``. Its
        dtype is float8_e4m3fn for FP8-serialized checkpoints and
        ``params_dtype`` otherwise.

        For FP8-serialized checkpoints the method also registers:
        - a weight scale: ``weight_scale`` (one entry per output shard), or
          ``weight_scale_inv`` (one entry per weight block) when block
          quantization is on;
        - ``input_scale``: one entry per shard for the static scheme, None
          for the dynamic scheme.

        Raises:
            ValueError: With block quantization, if a sharded/merged
                partition is not a multiple of the block size.
        """
        maybe_create_device_identity()

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        num_shards = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n = self.quant_config.weight_block_size[0]
            block_k = self.quant_config.weight_block_size[1]

            # Row-parallel sharding splits the input dim across ranks.
            row_sharded = (tp_size > 1 and
                           input_size // input_size_per_partition == tp_size)
            if row_sharded and input_size_per_partition % block_k != 0:
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

            # Column-parallel sharding or merged weights split the output dim.
            col_sharded = (tp_size > 1
                           and output_size // out_features == tp_size)
            if col_sharded or num_shards > 1:
                for shard_width in output_partition_sizes:
                    if shard_width % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{shard_width} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        serialized = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if serialized else params_dtype
        weight = ModelWeightParameter(data=torch.empty(
            out_features, input_size_per_partition, dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=loader)
        layer.register_parameter("weight", weight)

        # High-precision checkpoints get their scales in
        # process_weights_after_loading; nothing else to allocate here.
        if not serialized:
            return

        def init_scale(param, scale_type):
            # Fill with the most negative float32, presumably as a
            # "not loaded" sentinel (it cannot win the .max() taken over
            # input-scale shards later) -- TODO confirm for weight scales.
            param[:] = torch.finfo(torch.float32).min
            set_weight_attrs(param, {"scale_type": scale_type})
            return param

        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            scale_rows = (out_features + block_n - 1) // block_n
            scale_cols = (input_size_per_partition + block_k - 1) // block_k
            block_scale = BlockQuantScaleParameter(
                data=torch.empty(scale_rows, scale_cols, dtype=torch.float32),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            )
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv",
                                     init_scale(block_scale, "weight_scale"))
        else:
            tensor_scale = PerTensorScaleParameter(
                data=torch.empty(num_shards, dtype=torch.float32),
                weight_loader=loader,
            )
            layer.register_parameter("weight_scale",
                                     init_scale(tensor_scale, "weight_scale"))

        if self.quant_config.activation_scheme == "static":
            act_scale = PerTensorScaleParameter(
                data=torch.empty(num_shards, dtype=torch.float32),
                weight_loader=loader,
            )
            layer.register_parameter("input_scale",
                                     init_scale(act_scale, "input_scale"))
        else:
            layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally pad the weight rows on ROCm (VLLM_ROCM_FP8_PADDING).

        This is a ROCm optimization: tensors located far enough apart in
        memory can run faster. Padding applies only when the last dim is
        contiguous and the row pitch in bytes is a multiple of 512. The
        padding is 256 bytes per row; the returned view slices it off, so
        the logical shape is unchanged and only the row stride grows.
        """
        if not (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()):
            return weight
        if weight.stride(-1) != 1:
            return weight
        if (weight.stride(-2) * weight.element_size()) % 512 != 0:
            return weight

        pad_elems = 256 // weight.element_size()
        weight = F.pad(weight, (0, pad_elems), "constant", 0)[..., :-pad_elems]
        torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded parameters into the form the kernels consume.

        Exactly one of three paths runs:
        - block-quantized weights;
        - high-precision checkpoints, which are quantized here;
        - FP8-serialized per-tensor weights.

        When Marlin is in use, the layer is then repacked for Marlin and
        ``input_scale`` is removed.
        """
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            self._finalize_block_quant(layer)
        elif self.quant_config.is_checkpoint_fp8_serialized:
            self._finalize_serialized_fp8(layer)
        else:
            self._quantize_high_precision(layer)

        if self.use_marlin:
            # Block-quantized weights stay (N, K); the other paths store
            # the transposed (K, N) layout, i.e. "size K first".
            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
            # Activations not quantized for marlin.
            del layer.input_scale

    def _finalize_block_quant(self, layer: Module) -> None:
        """Normalize (on FNUZ platforms), pad, and re-wrap block weights."""
        assert self.quant_config.activation_scheme == "dynamic"
        if current_platform.is_fp8_fnuz():
            weight, scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=layer.weight, weight_scale=layer.weight_scale_inv)
        else:
            weight = layer.weight.data
            scale_inv = layer.weight_scale_inv.data

        weight = self._maybe_pad_weight(weight)

        # Torch.compile cannot use Parameter subclasses.
        layer.weight = Parameter(weight, requires_grad=False)
        layer.weight_scale_inv = Parameter(scale_inv, requires_grad=False)

    def _quantize_high_precision(self, layer: Module) -> None:
        """Quantize an FP16/BF16 weight to FP8 with a single dynamic scale."""
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        layer.input_scale = None

    def _finalize_serialized_fp8(self, layer: Module) -> None:
        """Handle FP8 checkpoints holding one scale per fused shard."""
        is_static = self.quant_config.activation_scheme == "static"

        layer.weight_scale = Parameter(layer.weight_scale.data,
                                       requires_grad=False)
        if is_static:
            layer.input_scale = Parameter(layer.input_scale.data,
                                          requires_grad=False)

        weight = layer.weight
        weight_scale = layer.weight_scale

        # torch._scaled_mm needs a per-tensor scale, so the logical shards
        # are dequantized and requantized with their max scale.
        if not self.use_marlin:
            if current_platform.is_fp8_fnuz():
                weight, weight_scale, fnuz_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=weight_scale,
                        input_scale=layer.input_scale)
                if fnuz_input_scale is not None:
                    layer.input_scale = Parameter(fnuz_input_scale,
                                                  requires_grad=False)
            weight_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=weight_scale,
                logical_widths=layer.logical_widths,
            )

        weight = self._maybe_pad_weight(weight)
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        if is_static:
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 matmul through the kernel chosen at construction."""
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if not self.block_quant:
            return self.fp8_linear.apply(input=x,
                                         weight=layer.weight,
                                         weight_scale=layer.weight_scale,
                                         out_dtype=self.out_dtype,
                                         input_scale=layer.input_scale,
                                         bias=bias)

        block_size = self.quant_config.weight_block_size
        assert block_size is not None
        return torch.ops.vllm.apply_w8a8_block_fp8_linear(
            input=x,
            weight=layer.weight,
            block_size=block_size,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
        )


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Choose the FP8 MoE kernel options for ``quant_config``.

        The constructor:
        - decides whether the Marlin weight-only path is used;
        - decides whether DeepGemm kernels may be used;
        - binds ``fused_experts`` with the FP8 options, so later calls only
          pass per-call arguments.

        Args:
            quant_config: The FP8 quantization config.
        """
        # Function-scope import kept as in the original; presumably avoids
        # importing the fused-MoE kernels at module load -- TODO confirm.
        from vllm.model_executor.layers.fused_moe import fused_experts
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # Check for DeepGemm support: it is opt-in via VLLM_USE_DEEP_GEMM and
        # only enabled for block-quantized models on CUDA with capability 90.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                # Single space between the two literal halves (the original
                # concatenation produced "Not using  DeepGemm").
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif (current_platform.is_cuda()
                  and current_platform.has_device_capability(90)):
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
            else:
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

        self.topk_indices_dtype = None
        self.fused_experts = functools.partial(  # type: ignore
            fused_experts,
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm)

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register the expert weights and scales on ``layer``.

        Parameters registered on ``layer``:
        - ``w13_weight``: shape (E, 2 * I, H), holding the gate and up
          projections.
        - ``w2_weight``: shape (E, H, I), the down projection.
        - Weight scales: per-tensor ``w13_weight_scale`` / ``w2_weight_scale``,
          or per-block ``*_weight_scale_inv`` under block quantization.
        - ``w13_input_scale`` / ``w2_input_scale``: per-expert, registered
          only for the static activation scheme (set to None otherwise).

        Weights are stored in float8_e4m3fn for FP8-serialized checkpoints
        and in ``params_dtype`` otherwise.

        Args:
            layer: Module that receives the parameters.
            num_experts: E, number of experts.
            hidden_size: H, model hidden size.
            intermediate_size_per_partition: I, per-rank intermediate size.
            params_dtype: Original (unquantized) parameter dtype.
            **extra_weight_attrs: Attributes forwarded to set_weight_attrs.
                A "quant_method" entry is added before the scales and input
                scales are tagged; the weights receive the dict without it.

        Raises:
            ValueError: If the block sizes do not divide the intermediate
                size, or a static activation scheme is requested for a
                non-FP8 checkpoint.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        serialized = self.quant_config.is_checkpoint_fp8_serialized
        storage_dtype = torch.float8_e4m3fn if serialized else params_dtype

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = self.quant_config.weight_block_size[0]
            block_k = self.quant_config.weight_block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        def frozen(tensor: torch.Tensor) -> torch.nn.Parameter:
            return torch.nn.Parameter(tensor, requires_grad=False)

        # WEIGHTS
        w13_weight = frozen(
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size,
                        dtype=storage_dtype))
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = frozen(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition,
                        dtype=storage_dtype))
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:

            def n_blocks(size: int, block: int) -> int:
                # Ceiling division: number of blocks covering `size`.
                return (size + block - 1) // block

            w13_weight_scale = frozen(
                torch.ones(
                    num_experts,
                    2 * n_blocks(intermediate_size_per_partition, block_n),
                    n_blocks(hidden_size, block_k),
                    dtype=torch.float32,
                ))
            w2_weight_scale = frozen(
                torch.ones(
                    num_experts,
                    n_blocks(hidden_size, block_n),
                    n_blocks(intermediate_size_per_partition, block_k),
                    dtype=torch.float32,
                ))
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
        else:
            # Two scales per expert for w1 and w3; they are combined into a
            # single scale after weight loading.
            w13_weight_scale = frozen(
                torch.ones(num_experts, 2, dtype=torch.float32))
            w2_weight_scale = frozen(torch.ones(num_experts,
                                                dtype=torch.float32))
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Record the scale granularity so the loader reads scales correctly.
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK if self.block_quant
                      else FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs.update({"quant_method": scale_kind.value})

        # Only FP8 checkpoints ship weight scales to load; for FP16 ones the
        # scales are produced in process_weights_after_loading().
        if serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
610
611
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
612
            is_rocm_aiter_moe_enabled, shuffle_weights)
613

614
615
        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

616
        # TODO (rob): refactor block quant into separate class.
617
        if self.block_quant:
618
            assert self.quant_config.activation_scheme == "dynamic"
619
            if current_platform.is_fp8_fnuz():
620
621
622
623
624
625
626
627
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
628
629
630
631
632
633
634
635
636
637
638
639
640
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
641
            if self.rocm_aiter_moe_enabled:
642
643
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
644
                    layer.w13_weight.data, layer.w2_weight.data)
645
646
647
648
649

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
650
651
652
653
654
655
656
657
658
659
660
661
662

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                # Lazy import to avoid CUDA initialization problems.
                import deep_gemm as dg
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()

663
        # If checkpoint is fp16, quantize in place.
664
        elif not self.quant_config.is_checkpoint_fp8_serialized:
665
            fp8_dtype = current_platform.fp8_dtype()
666
            w13_weight = torch.empty_like(layer.w13_weight.data,
667
668
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
669
670
671

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
672
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
673
                layer.local_num_experts,
674
675
                dtype=torch.float32,
                device=w13_weight.device),
676
                                                        requires_grad=False)
677
            for expert in range(layer.local_num_experts):
678
                w13_weight[expert, :, :], layer.w13_weight_scale[
679
680
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
681
                w2_weight[expert, :, :], layer.w2_weight_scale[
682
683
684
685
686
687
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
688
            if self.rocm_aiter_moe_enabled:
689
                # reshaping weights is required for aiter moe kernel.
690
691
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)
692
693
694
695
696

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
697
698
699
700
701
702
703
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
704
705
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
706
707
708
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
709
710
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
711
                    logger.warning_once(
712
713
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
714
                        "for each layer.")
715
716
717
718
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
719
            if current_platform.is_fp8_fnuz():
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)
744
745
746

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
747
            assert layer.w13_weight_scale is not None
748
            shard_size = layer.intermediate_size_per_partition
749
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
750
            for expert_id in range(layer.local_num_experts):
751
752
753
754
755
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
756
                        layer.w13_weight_scale[expert_id][shard_id])
757
                    layer.w13_weight[expert_id][
758
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
759
760
761
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

762
            if self.rocm_aiter_moe_enabled:
763
764
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)
765
766
767
768
769
770

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

771
772
            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
773
774
775
776
777
778

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale
779

780
    def select_gemm_impl(self, prepare_finalize, moe):
        """Build the fused-experts GEMM implementation used with all2all.

        Args:
            prepare_finalize: The prepare/finalize object the experts will be
                paired with. Its ``max_num_tokens_per_rank()`` result selects
                the implementation; for the batched variant its
                ``world_size`` and ``dp_size`` are also forwarded.
            moe: Not used by this method (presumably kept to match the
                base-class signature — TODO confirm).

        Returns:
            ``BatchedTritonOrDeepGemmExperts`` when
            ``prepare_finalize.max_num_tokens_per_rank()`` is not None,
            otherwise ``TritonOrDeepGemmExperts``. Both are configured for
            FP8 W8A8 with this config's ``weight_block_size`` and
            ``self.allow_deep_gemm``.

        Raises:
            AssertionError: If Marlin or ROCm AITER MoE is enabled; neither
                is supported together with all2all here.
        """
        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
            BatchedTritonOrDeepGemmExperts)
        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
            TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
        # A non-None per-rank token limit selects the batched experts.
        use_batched_experts = max_num_tokens_per_rank is not None

        if use_batched_experts:
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                world_size=prepare_finalize.world_size,
                dp_size=prepare_finalize.dp_size,
                use_fp8_w8a8=True,
                use_int8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
                per_channel_quant=False,
                block_shape=self.quant_config.weight_block_size,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        return TritonOrDeepGemmExperts(
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
        )
817

818
819
820
821
822
823
824
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route tokens to experts and run the FP8 fused-MoE computation.

        Expert selection is delegated to ``FusedMoE.select_experts``. The
        resulting ``topk_weights`` / ``topk_ids`` go to exactly one backend,
        checked in this order:

        1. ROCm AITER (``self.rocm_aiter_moe_enabled``):
           ``rocm_aiter_fused_experts``.
        2. Marlin (``self.use_marlin``): ``torch.ops.vllm.fused_marlin_moe``.
        3. Otherwise: ``self.fused_experts`` with ``inplace=True``.

        For the AITER and default backends, block-quantized layers
        (``self.block_quant``) pass ``w13_weight_scale_inv`` /
        ``w2_weight_scale_inv`` as weight scales; other layers pass
        ``w13_weight_scale`` / ``w2_weight_scale``. The Marlin backend always
        reads ``w13_weight_scale`` / ``w2_weight_scale``.

        Args:
            layer: The MoE layer whose weights, scales and input scales are
                read.
            x: Input hidden states.
            router_logits: Router output used for expert selection.
            top_k, renormalize, use_grouped_topk, topk_group,
            num_expert_group, custom_routing_function, scoring_func,
            e_score_correction_bias: Forwarded to
                ``FusedMoE.select_experts``.
            global_num_experts: Forwarded to the Marlin and default backends.
            expert_map: Forwarded to expert selection and to every backend.
            apply_router_weight_on_input: Forwarded to the AITER and default
                backends. Must be False for Marlin.
            activation: Activation name. Must be ``"silu"`` for Marlin.
            enable_eplb: When True, the three EPLB tensors below must be
                provided and ``layer`` must be a ``FusedMoE``.
            expert_load_view, logical_to_physical_map,
            logical_replica_count: EPLB state forwarded to
                ``FusedMoE.select_experts``.

        Returns:
            The tensor produced by the selected backend. The default backend
            is called with ``inplace=True``, so it may write into ``x`` —
            NOTE(review): confirm against ``fused_experts``.

        Raises:
            AssertionError: If EPLB is enabled with missing state or a
                non-``FusedMoE`` layer, or if Marlin is selected with a
                non-silu activation or ``apply_router_weight_on_input``.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert not apply_router_weight_on_input, (
                "Apply router weight on input not supported for Marlin MoE.")
            # Marlin reads the per-layer *_weight_scale attributes directly,
            # regardless of block_quant.
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                global_num_experts=global_num_experts,
                expert_map=expert_map)
        else:
            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )
921
922


923
924
925
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
926
927
928
    """

    def __init__(self, quant_config: Fp8Config):
929
        super().__init__(quant_config)