"docs/vscode:/vscode.git/clone" did not exist on "3097ce3a329bfe19ce3add0dbf5aaa46d7b99e88"
fp8.py 48.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
bnellnm's avatar
bnellnm committed
5
from typing import TYPE_CHECKING, Any, Callable, Optional
6
7

import torch
8
import torch.nn.functional as F
9
10
11
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
18
19
    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
20
21
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
22
from vllm.model_executor.layers.quantization import QuantizationMethods
23
from vllm.model_executor.layers.quantization.base_config import (
24
    QuantizationConfig, QuantizeMethodBase)
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
27
28
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
    swap_w13_to_w31)
29
30
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
31
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
32
33
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
34
from vllm.model_executor.layers.quantization.utils.quant_utils import (
35
    GroupShape, is_layer_skipped)
36
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
37
38
39
40
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
41
42
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
43
                                           PerTensorScaleParameter)
44
from vllm.model_executor.utils import set_weight_attrs
45
from vllm.platforms import current_platform
46
from vllm.scalar_type import scalar_types
47
from vllm.utils import has_deep_gemm
48
49
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
                                  is_deep_gemm_supported)
50
from vllm.utils.flashinfer import has_flashinfer_moe
51

52
53
54
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

55
56
57
58
# Activation quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger (used e.g. by Fp8MoEMethod to report kernel selection).
logger = init_logger(__name__)

59
60
61
62
63
64

def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

65

66
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights; False when FP16/BF16 weights are quantized after
            loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Layers to leave unquantized (matched through
            ``is_layer_skipped``). ``None`` becomes an empty list.
        weight_block_size: Optional two-element ``[block_n, block_k]`` that
            enables block-wise weight quantization.

    Raises:
        ValueError: For an unknown activation scheme, or when block-wise
            quantization is requested for a checkpoint that is not
            FP8-serialized, with a block size that is not two-dimensional,
            or with a non-dynamic activation scheme (checked in that order).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        # Block-wise quantization is only validated when it is requested;
        # the checks run lazily so later ones never see invalid input.
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError("The block-wise quantization only supports "
                                 "fp8-serialized checkpoint for now.")
            num_block_dims = len(weight_block_size)
            if num_block_dims != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {num_block_dims} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name this quantization method is registered under."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are needed for FP8."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` through ``hf_to_vllm_mapper``."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dict.

        ``quant_method`` and ``activation_scheme`` are required keys;
        ``ignored_layers`` and ``weight_block_size`` default to ``None``.
        The checkpoint counts as FP8-serialized when ``"fp8"`` occurs in
        ``quant_method``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer`` or ``None``.

        Linear layers get ``Fp8LinearMethod`` unless they are listed in
        ``ignored_layers`` (then ``UnquantizedLinearMethod``); fused MoE
        layers get ``Fp8MoEMethod``; attention layers get
        ``Fp8KVCacheMethod``. Anything else is not quantized.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skip = is_layer_skipped(prefix=prefix,
                                    ignored_layers=self.ignored_layers,
                                    fused_mapping=self.packed_modules_mapping)
            if skip:
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
        """
        if name.endswith(".output_scale"):
            # (marker in name, suffix to replace, vLLM replacement), tried in
            # k, v, q order.
            projection_rules = (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            )
            for marker, old_suffix, new_suffix in projection_rules:
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # No known pattern matched.
        return None

170
171
172

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Weight layouts handled by ``process_weights_after_loading``:

    * Block-wise FP8 (``weight_block_size`` set): one float32 scale per
      ``(block_n, block_k)`` tile, stored as ``weight_scale_inv``. Only the
      dynamic activation scheme is allowed.
    * Per-shard FP8 checkpoints: the scales of the logical shards of a fused
      layer are folded into a single max scale (skipped when Marlin is used).
    * FP16/BF16 checkpoints: quantized to FP8 after loading with
      ``ops.scaled_fp8_quant(..., scale=None)``.

    Limitations:
    1. Non-block weights use a single per-tensor weight scale due to
       torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_fp8_supported=cutlass_fp8_supported())

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, its scales.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Number of input columns on this rank.
            output_partition_sizes: Output rows of each logical shard held by
                this rank; their sum is the local output size and their
                count is the number of per-tensor scales.
            input_size: Full (unsharded) input size.
            output_size: Full (unsharded) output size.
            params_dtype: Weight dtype when the checkpoint is not
                FP8-serialized; also recorded as ``layer.orig_dtype``.
            **extra_weight_attrs: Only ``weight_loader`` is read here.

        Raises:
            ValueError: With block-wise quantization, if a tensor-parallel or
                merged-GEMM split is not divisible by the block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                # One scale per logical shard of the (possibly fused) layer.
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                # Placeholder value until the checkpoint scale is loaded.
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n, block_k) tile, rounded up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally move ``weight`` into a row-padded buffer on ROCm.

        Active only when ``VLLM_ROCM_FP8_PADDING`` is set on ROCm and the
        last dim is contiguous with a row stride (in bytes) divisible by 512.
        The tensor is padded by 256 bytes along its last dim and sliced back,
        so the returned view has the same shape and values but each row of
        the backing buffer is 256 bytes longer than the logical row.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the form ``apply`` expects.

        * Block-wise: optionally normalize to e4m3fnuz, pad, and re-wrap
          ``weight`` / ``weight_scale_inv`` as plain ``Parameter``s. The
          weight keeps its (out, in) layout.
        * FP16/BF16 checkpoint: quantize to FP8 and store the transposed
          weight.
        * FP8 checkpoint: unless Marlin is used, requantize the logical
          shards to one max scale; then pad and store the transposed weight.
          For static activations the input scales collapse to their max.

        Afterwards the layer is repacked for Marlin if enabled, and
        block-quantized layers are requantized to UE8M0 scales when DeepGemm
        E8M0 is in use on Blackwell.
        """
        # Non-block branches transpose the weight to (in, out); tell Marlin
        # which layout it receives.
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On B200, if E8M0 for DeepGemm is used, we need to
        # requantize the weight and input to the specific scale
        # at the same time. Only block-quantized layers carry block scales
        # (and reach the block/DeepGemm path in apply()); per-tensor layers
        # have weight_block_size = None and previously tripped the assert
        # below, so they are skipped.
        if self.block_quant and is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (plus optional ``bias``).

        Dispatches to the Marlin kernel, the block-wise W8A8 op, or the
        per-tensor ``Fp8LinearOp``, in that order of precedence.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
467
468


469
470
471
472
473
474
475
476
477
478
479
480
481
482
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Decide which FP8 MoE kernels may be used with ``quant_config``.

        Sets ``flashinfer_moe_enabled``, ``use_marlin``, ``allow_deep_gemm``
        and ``allow_cutlass_block_scaled_grouped_gemm`` (logging each
        decision), then binds ``self.fused_experts`` to ``fused_experts``
        with the FP8 W8A8 options filled in.

        Args:
            quant_config: The FP8 quantization config.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        # FlashInfer MoE is opt-in via env var and needs the package present.
        self.flashinfer_moe_enabled = bool(envs.VLLM_USE_FLASHINFER_MOE_FP8
                                           and has_flashinfer_moe())
        if self.flashinfer_moe_enabled:
            logger.info_once(
                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")

        # Marlin gives weight-only FP8 on GPUs without device capability 89
        # (no FP8 hardware) or when forced for testing; never on ROCm.
        wants_marlin = (not current_platform.has_device_capability(89)
                        or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        on_rocm = current_platform.is_rocm()
        self.use_marlin = False if on_rocm else wants_marlin

        # DeepGemm is opt-in and only usable for block-quantized models.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. "
                                    "Not using DeepGemm kernels")
            elif not is_deep_gemm_supported():
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")
            else:
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True

        # The CUTLASS block-scaled grouped GEMM needs block quantization and
        # a CUDA device with capability 100.
        self.allow_cutlass_block_scaled_grouped_gemm = False
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. "
                              "Not using CutlassBlockScaledGroupedGemm kernels")
        elif not (current_platform.is_cuda()
                  and current_platform.is_device_capability(100)):
            logger.warning_once("CutlassBlockScaledGroupedGemm not supported "
                                "on the current platform.")
        else:
            logger.info_once("Using CutlassBlockScaledGroupedGemm kernels "
                             "for Fp8MoEMethod.")
            self.allow_cutlass_block_scaled_grouped_gemm = True

        self.topk_indices_dtype = None
        self.fused_experts = functools.partial(  # type: ignore
            fused_experts,
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
            allow_cutlass_block_scaled_grouped_gemm=(
                self.allow_cutlass_block_scaled_grouped_gemm))
540

541
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register expert weights, weight scales and input scales.

        Registration order: ``w13_weight``, ``w2_weight``, then the two weight
        scales (``*_weight_scale`` per-tensor, ``*_weight_scale_inv``
        block-wise), then for the static scheme ``w13_input_scale`` and
        ``w2_input_scale``.

        Args:
            layer: The MoE module receiving the parameters.
            num_experts: Number of experts (leading dim of every tensor).
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-rank expert intermediate
                size; ``w13`` stacks the gate and up projections, so it has
                twice this many rows.
            params_dtype: Weight dtype for non-FP8 checkpoints; recorded as
                ``layer.orig_dtype`` before being replaced with
                ``torch.float8_e4m3fn`` for FP8-serialized checkpoints.
            **extra_weight_attrs: Attributes attached with
                ``set_weight_attrs`` to the weights and, for FP8-serialized
                checkpoints, also to the scales (which additionally receive a
                ``quant_method`` entry).

        Raises:
            ValueError: If the block size does not divide the intermediate
                size, or a static activation scheme is used with a checkpoint
                that is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        if fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            block_size = self.quant_config.weight_block_size
            assert block_size is not None
            layer.weight_block_size = block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = block_size[0], block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS: attributes are attached here, before "quant_method" is
        # added to extra_weight_attrs below.
        weight_specs = (
            ("w13_weight",
             (num_experts, 2 * intermediate_size_per_partition, hidden_size)),
            ("w2_weight",
             (num_experts, hidden_size, intermediate_size_per_partition)),
        )
        for weight_name, weight_shape in weight_specs:
            expert_weight = torch.nn.Parameter(
                torch.empty(*weight_shape, dtype=params_dtype),
                requires_grad=False)
            layer.register_parameter(weight_name, expert_weight)
            set_weight_attrs(expert_weight, extra_weight_attrs)

        def ceil_div(numerator: int, denominator: int) -> int:
            return (numerator + denominator - 1) // denominator

        # WEIGHT_SCALES
        if self.block_quant:
            # One float32 scale per (block_n, block_k) tile of every expert.
            w13_scale_name = "w13_weight_scale_inv"
            w2_scale_name = "w2_weight_scale_inv"
            w13_scale_shape = (
                num_experts,
                2 * ceil_div(intermediate_size_per_partition, block_n),
                ceil_div(hidden_size, block_k),
            )
            w2_scale_shape = (
                num_experts,
                ceil_div(hidden_size, block_n),
                ceil_div(intermediate_size_per_partition, block_k),
            )
        else:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_scale_name = "w13_weight_scale"
            w2_scale_name = "w2_weight_scale"
            w13_scale_shape = (num_experts, 2)
            w2_scale_shape = (num_experts, )
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(*w13_scale_shape, dtype=torch.float32),
            requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(*w2_scale_shape, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter(w13_scale_name, w13_weight_scale)
        layer.register_parameter(w2_scale_name, w2_weight_scale)
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor or per block) so the
        # weight scales are loaded in properly.
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK
                      if self.block_quant else
                      FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs["quant_method"] = scale_kind.value

        # Only FP8 checkpoints provide scales to load; for FP16/BF16
        # checkpoints the weights are quantized in
        # process_weights_after_loading().
        if fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for input_scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32),
                requires_grad=False)
            layer.register_parameter(input_scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
668
669

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded MoE weights/scales into the kernel layout.

        One of three paths runs, chosen by the quantization config:

        * block quant: normalize to e4m3fnuz on fnuz platforms, or swap the
          w1/w3 halves for FlashInfer; re-wrap everything as plain
          ``Parameter``s; shuffle for ROCm AITER; and align the block
          scales for DeepGEMM.
        * non-fp8 (e.g. fp16) checkpoint: quantize every expert in place to
          fp8, with one scale per expert for w13 and for w2.
        * per-tensor fp8 checkpoint: collapse static input scales to their
          max, normalize on fnuz platforms, and requantize both w13 shards
          of each expert to a single shared (max) weight scale.

        Afterwards the layer is converted for Marlin when enabled. When
        Blackwell DeepGEMM uses UE8M0 scales, block-quant weights are
        requantized accordingly.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # The returned input scales are not used below (the
                # activation scheme is dynamic), so they are discarded.
                w13_weight, w13_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_enabled:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
                # NOTE(review): only block-quant layers reach this branch,
                # so per-tensor FlashInfer layers get no swap/rotation in
                # this method. Confirm that the per-tensor FlashInfer path
                # in apply() does not need rotated weights.
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
                self._align_block_scales_for_deep_gemm(layer)

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(layer.local_num_experts,
                           dtype=torch.float32,
                           device=w13_weight.device),
                requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # w13 holds two stacked shards of `shard_size` rows each;
                # every shard is requantized with the shared max scale.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        # UE8M0 requantization works on block scales (`*_weight_scale_inv`),
        # which only block-quant layers register; per-tensor layers would
        # otherwise hit the assert / a missing attribute here.
        if self.block_quant and is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            self._align_block_scales_for_deep_gemm(layer)

    @staticmethod
    def _align_block_scales_for_deep_gemm(layer: Module) -> None:
        """Replace column-major block scales with TMA-aligned copies.

        Only ``w13_weight_scale_inv`` / ``w2_weight_scale_inv`` tensors for
        which ``_is_col_major`` is true are converted. The aligned tensor is
        re-wrapped in a ``Parameter`` because both attributes are registered
        parameters, and ``Module.__setattr__`` raises ``TypeError`` when a
        plain tensor is assigned to a parameter name.
        """
        for name in ("w13_weight_scale_inv", "w2_weight_scale_inv"):
            scale = getattr(layer, name)
            if _is_col_major(scale):
                aligned = get_col_major_tma_aligned_tensor(scale).contiguous()
                setattr(layer, name, Parameter(aligned, requires_grad=False))

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Choose the fused-experts implementation for an all2all kernel.

        Dispatchers with the ``BatchedExperts`` activation format get a
        ``BatchedTritonOrDeepGemmExperts`` sized by the dispatcher's per-rank
        token limit. All other formats get ``TritonOrDeepGemmExperts``. Both
        use fp8 w8a8 with the configured weight block shape.

        Args:
            prepare_finalize: All2all prepare/finalize object. Its activation
                format and dispatch sizes drive the choice.
            moe: MoE config; not used by this implementation.

        Returns:
            The experts implementation to pair with ``prepare_finalize``.

        Raises:
            AssertionError: If Marlin or ROCm AITER is enabled, or if a
                batched dispatcher reports no per-rank token limit.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            max_num_tokens_per_rank = (
                prepare_finalize.max_num_tokens_per_rank())
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_num_tokens_per_rank,
                self.quant_config.weight_block_size, False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                per_act_token_quant=False,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, self.quant_config.weight_block_size,
                False)
            return TritonOrDeepGemmExperts(
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the fp8 fused-MoE forward pass on ``x``.

        Unless FlashInfer MoE is enabled, experts are first selected with
        ``FusedMoE.select_experts``; the FlashInfer kernels take the raw
        router logits instead. Dispatch then tries, in order: ROCm AITER,
        Marlin, FlashInfer (block-scale or per-tensor), and finally
        ``self.fused_experts``. Block-quant layers pass
        ``*_weight_scale_inv`` as weight scales; per-tensor layers pass
        ``*_weight_scale``.

        Returns:
            The MoE output tensor produced by the selected kernel.

        Raises:
            AssertionError: If EPLB is enabled without its bookkeeping
                tensors (or ``layer`` is not a ``FusedMoE``), or if a
                kernel's activation/routing preconditions are not met.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        # NOTE: topk_weights/topk_ids are only defined on this path; the
        # FlashInfer branch below routes from `router_logits` itself.
        if not self.flashinfer_moe_enabled:
            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
                indices_type=self.topk_indices_dtype,
                enable_eplb=enable_eplb,
                expert_map=expert_map,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)
        elif self.flashinfer_moe_enabled:
            assert activation == 'silu'
            assert scoring_func == 'sigmoid'
            if self.block_quant:
                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)

                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.quant_config.weight_block_size,
                    routed_scaling=1.0,
                )
            else:
                assert (not renormalize
                        and custom_routing_function is not None)
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1057
1058
1059
    """

    def __init__(self, quant_config: Fp8Config):
1060
        super().__init__(quant_config)