fp8.py 48.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
bnellnm's avatar
bnellnm committed
5
from typing import TYPE_CHECKING, Any, Callable, Optional
6
7

import torch
8
import torch.nn.functional as F
9
10
11
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
18
19
    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
20
21
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
22
from vllm.model_executor.layers.quantization import QuantizationMethods
23
from vllm.model_executor.layers.quantization.base_config import (
24
    QuantizationConfig, QuantizeMethodBase)
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
27
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
28
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
29
30
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
31
from vllm.model_executor.layers.quantization.utils.quant_utils import (
32
    GroupShape, is_layer_skipped)
33
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
34
35
36
37
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
38
39
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
40
                                           PerTensorScaleParameter)
41
from vllm.model_executor.utils import set_weight_attrs
42
from vllm.platforms import current_platform
43
from vllm.scalar_type import scalar_types
44
from vllm.utils import has_deep_gemm
45
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
46
from vllm.utils.flashinfer import has_flashinfer_moe
47

48
49
50
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

51
52
53
54
# Activation quantization schemes accepted by Fp8Config(activation_scheme=...);
# any other value makes Fp8Config.__init__ raise ValueError.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger used by the FP8 linear/MoE methods below.
logger = init_logger(__name__)

55

56
57
58
59
60
def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of ``x`` along its second-to-last dimension.

    ``x`` is viewed as ``(-1, 2, rows // 2, cols)``, the size-2 axis is
    reversed, and the result is reshaped back to ``x.shape``. Going by the
    name, this turns a fused w1/w3 tensor into w3/w1 order. ``rows`` must be
    even, otherwise the reshape raises.
    """
    rows, cols = x.shape[-2], x.shape[-1]
    halves = x.reshape(-1, 2, rows // 2, cols)
    swapped = halves.flip(dims=[1])
    return swapped.reshape(x.shape)


61
62
63
64
65
def _is_col_major(x: torch.Tensor) -> bool:
    """Return True if the 3-D tensor ``x`` has batched column-major strides.

    For shape ``(b, m, n)`` that means strides ``(m * n, 1, m)``: every batch
    entry is dense and laid out column by column.
    """
    assert x.dim() == 3
    _, rows, cols = x.shape
    return x.stride() == (rows * cols, 1, rows)

66

67
class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint already
                stores FP8 weights.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Prefixes checked by ``is_layer_skipped``;
                matching linear layers are left unquantized.
            weight_block_size: Optional ``[block_n, block_k]`` for
                block-wise weight quantization.

        Raises:
            ValueError: On an unknown activation scheme or an invalid
                block-wise configuration.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._check_weight_block_size(weight_block_size,
                                          is_checkpoint_fp8_serialized,
                                          activation_scheme)
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_weight_block_size(weight_block_size: list[int],
                                 is_checkpoint_fp8_serialized: bool,
                                 activation_scheme: str) -> None:
        """Raise ValueError when block-wise quantization is misconfigured.

        Block-wise quantization needs an FP8-serialized checkpoint, a
        two-element block size and the dynamic activation scheme.
        """
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now.")
        if len(weight_block_size) != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {len(weight_block_size)} dimensions")
        if activation_scheme != "dynamic":
            raise ValueError("The block-wise quantization only supports "
                             "dynamic activation scheme for now, but got "
                             f"{activation_scheme} activation scheme.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are needed for FP8."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config.

        The checkpoint counts as FP8-serialized when its ``quant_method``
        contains ``"fp8"``. ``activation_scheme`` is required;
        ``ignored_layers`` and ``weight_block_size`` are optional.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in quant_method
        # Keyword arguments are evaluated left to right, so the config keys
        # are read in the same order as before.
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=cls.get_from_keys(config,
                                                ["activation_scheme"]),
            ignored_layers=cls.get_from_keys_or(config, ["ignored_layers"],
                                                None),
            weight_block_size=cls.get_from_keys_or(config,
                                                   ["weight_block_size"],
                                                   None),
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the FP8 quantize method for ``layer``, if any.

        Linear layers get ``Fp8LinearMethod`` unless skipped via
        ``ignored_layers``; fused MoE layers get ``Fp8MoEMethod``; attention
        layers get ``Fp8KVCacheMethod``; anything else gets ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping)
            return (UnquantizedLinearMethod()
                    if skipped else Fp8LinearMethod(self))
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v/q cache
        scales (or the attention prob scale) in compressed-tensors. If this
        is the case, return its equivalent param name expected by vLLM.

        :param name: param name
        :return: matching param name for the attention scale in vLLM, or
            None if the name does not match
        """
        if name.endswith(".output_scale"):
            # Checked in k, v, q order; the first projection found wins.
            for marker, old, new in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if marker in name:
                    return name.replace(old, new)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None

171
172
173

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale, optionally with block-wise weight
    quantization (``weight_block_size``; dynamic activations only).

    Also supports loading unquantized FP16/BF16 model checkpoints with
    dynamic activation scaling: the weights are quantized to FP8 and the
    weight scaling factor is computed after the model weights are loaded.

    Limitations:
    1. Without block quantization (and without Marlin), the logical shards
       of a fused layer are requantized to a single per-tensor weight scale,
       because torch._scaled_mm needs per-tensor weight scales.
    2. Checkpoint weights must be float8_e4m3fn (they are converted to
       e4m3fnuz on platforms where ``current_platform.is_fp8_fnuz()``) due
       to the limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ; currently
        # that means MI300-series GPUs.
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_fp8_supported=cutlass_fp8_supported())

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and scale parameters on ``layer``.

        Always registers ``weight`` with shape
        ``(sum(output_partition_sizes), input_size_per_partition)``, stored
        as ``float8_e4m3fn`` for FP8-serialized checkpoints and as
        ``params_dtype`` otherwise. For FP8-serialized checkpoints it also
        registers either ``weight_scale`` (one entry per logical output
        partition) or, with block quantization, ``weight_scale_inv`` (one
        entry per ``block_n x block_k`` tile). It then registers
        ``input_scale`` for static activations, or ``None`` for dynamic.

        Raises:
            ValueError: if the partition sizes are not divisible by the
                quantization block sizes where the kernels require it.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                # Placeholder fill until the checkpoint values are loaded.
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Pad the last dim of ``weight`` on ROCm when enabled, keeping its shape.

        Padding only happens when ``VLLM_ROCM_FP8_PADDING`` is set on ROCm,
        the last dim is contiguous, and the row stride in bytes is a
        multiple of 512. The returned view has the original logical shape
        but a row stride that is 256 bytes larger.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights and scales into the form ``apply`` expects.

        * Block-quantized: normalize to e4m3fnuz where needed, pad, and
          re-wrap as plain ``Parameter``s.
        * Unserialized (FP16/BF16) checkpoint: quantize the weight to FP8
          with a dynamically computed scale and store it transposed.
        * FP8 checkpoint: unless Marlin is used, requantize the logical
          shards to one per-tensor scale; then pad and store transposed.

        Afterwards the layer is prepared for Marlin when enabled. On
        Blackwell DeepGemm, block-quantized weights are requantized to
        UE8M0 scales.
        """
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On B200, DeepGemm only supports E8M0 scales, so block-quantized
        # weights must be requantized to that scale format. Per-tensor layers
        # have no weight_block_size and never take the block-wise path in
        # apply(), so they are skipped; an unconditional assert here would
        # crash every non-block FP8 model on Blackwell with DeepGemm enabled.
        if self.block_quant and is_blackwell_deep_gemm_used():
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatches to the Marlin weight-only kernel, the block-wise w8a8
        kernel, or ``Fp8LinearOp`` (per-tensor weights with per-tensor or
        per-token activation quantization), in that order of precedence.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
468
469


470
471
472
473
474
475
476
477
478
479
480
481
482
483
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Record which FP8 MoE kernel paths are enabled for this config.

        Sets ``flashinfer_moe_enabled``, ``use_marlin``, ``allow_deep_gemm``
        and ``allow_cutlass_block_scaled_grouped_gemm`` from env flags, the
        current platform and whether the config is block-quantized. It then
        binds those choices into ``self.fused_experts``, a
        ``functools.partial`` over ``fused_experts``.

        Args:
            quant_config: The quantization config.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        self.quant_config = quant_config
        weight_block_size = self.quant_config.weight_block_size
        self.block_quant = weight_block_size is not None

        self.flashinfer_moe_enabled = bool(envs.VLLM_USE_FLASHINFER_MOE_FP8
                                           and has_flashinfer_moe())
        if self.flashinfer_moe_enabled:
            logger.info_once(
                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")

        # Marlin gives fast weight-only FP8 on GPUs lacking FP8 hardware
        # support (no capability 89), or when forced via env; never on ROCm.
        lacks_native_fp8 = not current_platform.has_device_capability(89)
        wants_marlin = lacks_native_fp8 or envs.VLLM_TEST_FORCE_FP8_MARLIN
        self.use_marlin = False if current_platform.is_rocm() else wants_marlin

        # DeepGemm is only considered when explicitly requested via env.
        self.allow_deep_gemm = (self._deep_gemm_allowed()
                                if envs.VLLM_USE_DEEP_GEMM else False)
        self.allow_cutlass_block_scaled_grouped_gemm = (
            self._cutlass_block_scaled_grouped_gemm_allowed())

        self.topk_indices_dtype = None
        self.fused_experts = functools.partial(  # type: ignore
            fused_experts,
            use_fp8_w8a8=True,
            block_shape=weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
            allow_cutlass_block_scaled_grouped_gemm=(
                self.allow_cutlass_block_scaled_grouped_gemm),
        )

    def _deep_gemm_allowed(self) -> bool:
        """Return whether DeepGemm may be used, logging why either way.

        Requires the DeepGemm package, a block-quantized model, and either a
        CUDA SM90 device or Blackwell DeepGemm.
        """
        if not has_deep_gemm():
            logger.warning_once("Failed to import DeepGemm kernels.")
            return False
        if not self.block_quant:
            logger.warning_once("Model is not block quantized. Not using "
                                "DeepGemm kernels")
            return False
        if (current_platform.is_cuda()
                and current_platform.is_device_capability(90)):
            logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
            return True
        if current_platform.is_cuda() and is_blackwell_deep_gemm_used():
            logger.info_once("Using DeepGemm SM100 kernels for "
                             "Fp8MoEMethod.")
            return True
        logger.warning_once(
            "DeepGemm not supported on the current platform.")
        return False

    def _cutlass_block_scaled_grouped_gemm_allowed(self) -> bool:
        """Return whether CutlassBlockScaledGroupedGemm may be used.

        Requires a block-quantized model on a CUDA SM100 device; the reason
        is logged either way.
        """
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
            return False
        if (current_platform.is_cuda()
                and current_platform.is_device_capability(100)):
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
            return True
        logger.warning_once(
            "CutlassBlockScaledGroupedGemm not supported on the current "
            "platform.")
        return False
547

548
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the FP8 MoE weight, weight-scale and input-scale params.

        Registers ``w13_weight`` with shape
        ``(num_experts, 2 * intermediate_size_per_partition, hidden_size)``
        and ``w2_weight`` with shape
        ``(num_experts, hidden_size, intermediate_size_per_partition)``.
        Both are stored as ``float8_e4m3fn`` for FP8 checkpoints and as
        ``params_dtype`` otherwise.

        Weight scales are per-tensor (``w13_weight_scale`` with 2 entries
        per expert, ``w2_weight_scale`` with 1). With block quantization
        they are per-block (``w13_weight_scale_inv`` /
        ``w2_weight_scale_inv``). A static activation scheme additionally
        registers per-expert ``w13_input_scale`` / ``w2_input_scale``.

        Raises:
            ValueError: if the intermediate size is not divisible by the
                block sizes, or a static activation scheme is requested for
                a checkpoint that is not FP8-serialized.
        """

        def _cdiv(num: int, den: int) -> int:
            return (num + den - 1) // den

        def _frozen(data: torch.Tensor) -> torch.nn.Parameter:
            return torch.nn.Parameter(data, requires_grad=False)

        inter = intermediate_size_per_partition
        cfg = self.quant_config
        is_fp8_ckpt = cfg.is_checkpoint_fp8_serialized

        layer.intermediate_size_per_partition = inter
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        weight_dtype = torch.float8_e4m3fn if is_fp8_ckpt else params_dtype

        if self.block_quant:
            assert cfg.weight_block_size is not None
            layer.weight_block_size = cfg.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = cfg.weight_block_size[0]
            block_k = cfg.weight_block_size[1]
            # The gate/up (w13) output rows must split evenly into block_n
            # so each block-wise scale lines up with one block
            # (column parallel / merged weights).
            if inter % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{inter} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Row parallel: the down projection's input dim is sharded.
            if tp_size > 1 and inter % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{inter} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS (weight loaders attached before "quant_method" is added)
        w13_weight = _frozen(
            torch.empty(num_experts,
                        2 * inter,
                        hidden_size,
                        dtype=weight_dtype))
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = _frozen(
            torch.empty(num_experts, hidden_size, inter, dtype=weight_dtype))
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            w13_weight_scale = _frozen(
                torch.ones(num_experts,
                           2 * _cdiv(inter, block_n),
                           _cdiv(hidden_size, block_k),
                           dtype=torch.float32))
            w2_weight_scale = _frozen(
                torch.ones(num_experts,
                           _cdiv(hidden_size, block_n),
                           _cdiv(inter, block_k),
                           dtype=torch.float32))
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert cfg.activation_scheme == "dynamic"
            scale_kind = FusedMoeWeightScaleSupported.BLOCK.value
        else:
            # Separate scales for w1 and w3; they are combined into a single
            # scale after weight loading.
            w13_weight_scale = _frozen(
                torch.ones(num_experts, 2, dtype=torch.float32))
            w2_weight_scale = _frozen(
                torch.ones(num_experts, dtype=torch.float32))
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            scale_kind = FusedMoeWeightScaleSupported.TENSOR.value

        # Tag everything registered from here on with the scale granularity
        # so the weight loader handles the scales properly.
        extra_weight_attrs["quant_method"] = scale_kind

        # Only FP8 checkpoints carry weight scales to load; for FP16/BF16
        # checkpoints quantization happens in process_weights_after_loading().
        if is_fp8_ckpt:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if cfg.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not is_fp8_ckpt:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for param_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = _frozen(
                torch.ones(num_experts, dtype=torch.float32))
            layer.register_parameter(param_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
675
676

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded MoE expert weights into the layout the kernels use.

        Three checkpoint formats are handled:

        * block-quantized (``*_weight_scale_inv`` scales): normalize to
          e4m3fnuz where the platform requires it, swap the w1/w3 halves of
          w13 for FlashInfer, and pre-align the scales for DeepGEMM;
        * non-fp8 (e.g. fp16) checkpoints: quantize every expert to fp8 with
          one dynamically computed scale per expert for w13 and for w2;
        * per-tensor fp8 checkpoints: collapse static input scales and the
          two w13 shard scales to a single max scale per expert, re-quantizing
          w13 to that scale.

        Afterwards the layer is repacked for Marlin or re-quantized to UE8M0
        scales for Blackwell DeepGEMM when those paths are active.

        Args:
            layer: The MoE layer whose parameters were registered by
                ``create_weights`` and filled by the checkpoint loader.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled)

        # Read again in apply() / select_gemm_impl().
        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            self._process_block_quant_moe_weights(layer)
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            # If checkpoint is fp16, quantize in place.
            self._quantize_unquantized_moe_weights(layer)
        else:
            # If checkpoint is fp8, we need to handle that the MoE kernels
            # require single activation scale and single weight scale for
            # w13 per expert.
            self._process_per_tensor_fp8_moe_weights(layer)

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        # Only block-quantized layers register ``*_weight_scale_inv`` (the
        # per-tensor path registers ``*_weight_scale``), so the UE8M0
        # re-quantization below is skipped for them instead of crashing.
        if self.block_quant and is_blackwell_deep_gemm_used():
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv).contiguous()
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv).contiguous()

    def _shuffle_moe_weights_for_rocm_aiter(self, layer: Module) -> None:
        """Replace w13/w2 with the layout required by the ROCm AITER kernel.

        No-op unless ``self.rocm_aiter_moe_enabled`` is set.
        """
        if not self.rocm_aiter_moe_enabled:
            return
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            shuffle_weights)

        # reshaping weights is required for aiter moe kernel.
        shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
                                                    layer.w2_weight.data)
        layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

    def _process_block_quant_moe_weights(self, layer: Module) -> None:
        """Finalize block-quantized expert weights and their inverse scales."""
        assert self.quant_config.activation_scheme == "dynamic"
        if current_platform.is_fp8_fnuz():
            # The normalized input scales are not needed: block quant is
            # dynamic-activation only (asserted above).
            w13_weight, w13_weight_scale_inv, _ = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale_inv,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale_inv, _ = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale_inv,
                    layer.w2_input_scale)
        elif self.flashinfer_moe_enabled:
            # NOTE: weights have to be swapped since the activation is
            # applied on different half for flashinfer vs vllm
            w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
            w13_weight_scale_inv = _swap_w13_to_w31(
                layer.w13_weight_scale_inv.data)
            w2_weight = layer.w2_weight.data
            w2_weight_scale_inv = layer.w2_weight_scale_inv.data
        else:
            # Unwrap w2 via .data exactly like w13 so every branch hands
            # plain tensors to the Parameter re-wrap below.
            w13_weight = layer.w13_weight.data
            w13_weight_scale_inv = layer.w13_weight_scale_inv.data
            w2_weight = layer.w2_weight.data
            w2_weight_scale_inv = layer.w2_weight_scale_inv.data

        # torch.compile() cannot use Parameter subclasses.
        layer.w13_weight = Parameter(w13_weight, requires_grad=False)
        layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                               requires_grad=False)
        layer.w2_weight = Parameter(w2_weight, requires_grad=False)
        layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                              requires_grad=False)

        self._shuffle_moe_weights_for_rocm_aiter(layer)

        # DeepGemm scales need to be transposed and aligned.  We try to do
        # it ahead of time for performance reasons.  On Blackwell this is
        # done after the UE8M0 re-quantization instead.
        if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv).contiguous()
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv).contiguous()

    def _quantize_unquantized_moe_weights(self, layer: Module) -> None:
        """Quantize non-fp8 expert weights to fp8, one scale per expert."""
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # Re-initialize w13_scale because we directly quantize
        # merged w13 weights and generate a single scaling factor.
        layer.w13_weight_scale = torch.nn.Parameter(
            torch.ones(layer.local_num_experts,
                       dtype=torch.float32,
                       device=w13_weight.device),
            requires_grad=False)
        for expert in range(layer.local_num_experts):
            # The right-hand sides still read the original high-precision
            # weights; layer.w*_weight is only replaced after the loop.
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = \
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = \
                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
        layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

        self._shuffle_moe_weights_for_rocm_aiter(layer)

    def _process_per_tensor_fp8_moe_weights(self, layer: Module) -> None:
        """Collapse per-tensor fp8 scales to one scale per expert.

        Raises:
            ValueError: If the activation scheme is static but the input
                scales were not loaded.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            # w13 stacks two shards of `shard_size` rows, each with its own
            # loaded scale; re-quantize both to the expert's max scale.
            for shard_id in range(2):
                rows = slice(shard_id * shard_size,
                             (shard_id + 1) * shard_size)
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][rows, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][rows, :], _ = \
                    ops.scaled_fp8_quant(dq_weight,
                                         max_w13_scales[expert_id])

        self._shuffle_moe_weights_for_rocm_aiter(layer)

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Build the expert-compute kernel that pairs with ``prepare_finalize``.

        When ``prepare_finalize`` uses the batched-experts activation format,
        a ``BatchedTritonOrDeepGemmExperts`` sized by the dispatcher's
        per-rank token budget is returned. Otherwise a
        ``TritonOrDeepGemmExperts`` is returned. Both are configured for fp8
        w8a8 with the configured weight block shape. ``moe`` is not used here.

        Raises:
            AssertionError: If Marlin or ROCm AITER MoE is active (not
                supported with all2all yet), or if the batched dispatcher
                reports no per-rank token budget.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        block_shape = self.quant_config.weight_block_size
        is_batched = (prepare_finalize.activation_format ==
                      FusedMoEActivationFormat.BatchedExperts)

        if not is_batched:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, block_shape, False)
            return TritonOrDeepGemmExperts(
                use_fp8_w8a8=True,
                block_shape=block_shape,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens_per_rank is not None
        logger.debug(
            "BatchedTritonOrDeepGemmExperts(%s): "
            "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
            self.__class__.__name__, max_num_tokens_per_rank, block_shape,
            False)
        return BatchedTritonOrDeepGemmExperts(
            max_num_tokens=max_num_tokens_per_rank,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            use_fp8_w8a8=True,
            block_shape=block_shape,
            per_act_token_quant=False,
            allow_deep_gemm=self.allow_deep_gemm,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the fp8 MoE forward pass for ``x`` on the active backend.

        The FlashInfer block-scale kernel takes the raw ``router_logits`` and
        does its own routing, so it is dispatched first. Every other backend
        (ROCm AITER, Marlin, or ``self.fused_experts``) consumes the
        ``topk_weights``/``topk_ids`` computed by ``FusedMoE.select_experts``.
        Computing those before any of those branches guarantees they are
        always defined where used.

        Args:
            layer: MoE layer holding the processed fp8 weights and scales.
            x: Input hidden states.
            router_logits: Router output used to select experts.
            top_k: Number of experts selected per token.
            renormalize: Forwarded to expert selection.
            activation: Expert activation; Marlin and FlashInfer only
                support ``"silu"``.
            enable_eplb: If True, ``expert_load_view``,
                ``logical_to_physical_map`` and ``logical_replica_count``
                must be given and ``layer`` must be a ``FusedMoE``.
            The remaining routing arguments are forwarded to
            ``FusedMoE.select_experts`` and/or the selected kernel.

        Returns:
            The tensor produced by the selected MoE kernel.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_enabled:
            # Currently only work with DS models
            assert self.block_quant
            assert (renormalize and use_grouped_topk
                    and scoring_func == 'sigmoid'
                    and custom_routing_function is None)
            assert activation == "silu"
            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                routing_logits=router_logits.to(torch.float32),
                routing_bias=e_score_correction_bias,
                x=x,
                w13_weight=layer.w13_weight,
                w13_weight_scale_inv=layer.w13_weight_scale_inv,
                w2_weight=layer.w2_weight,
                w2_weight_scale_inv=layer.w2_weight_scale_inv,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                block_shape=self.quant_config.weight_block_size,
                routed_scaling=1.0,
            )

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)

        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        return self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1049
1050
1051
    """

    def __init__(self, quant_config: Fp8Config):
1052
        super().__init__(quant_config)