fp8.py 45.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
bnellnm's avatar
bnellnm committed
5
from typing import TYPE_CHECKING, Any, Callable, Optional
6
7

import torch
8
import torch.nn.functional as F
9
10
11
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
18
19
    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
20
21
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
22
from vllm.model_executor.layers.quantization import QuantizationMethods
23
from vllm.model_executor.layers.quantization.base_config import (
24
    QuantizationConfig, QuantizeMethodBase)
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
27
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
28
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
29
30
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
31
from vllm.model_executor.layers.quantization.utils.quant_utils import (
32
    GroupShape, is_layer_skipped)
33
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
34
35
36
37
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
38
39
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
40
                                           PerTensorScaleParameter)
41
from vllm.model_executor.utils import set_weight_attrs
42
from vllm.platforms import current_platform
43
from vllm.scalar_type import scalar_types
44
from vllm.utils import has_deep_gemm
45
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
46

47
48
49
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

50
51
52
53
# Activation quantization schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = [
    "static",
    "dynamic",
]
logger = init_logger(__name__)

54
55
56
57
58
59

def _is_col_major(x: torch.Tensor) -> bool:
    """Return True if the 3-D tensor ``x`` is a batch of column-major matrices.

    For ``x`` of shape ``(b, m, n)`` this requires unit stride along dim 1,
    stride ``m`` along dim 2, and batches packed back-to-back with stride
    ``m * n`` along dim 0.

    Raises:
        AssertionError: If ``x`` is not 3-dimensional.
    """
    assert x.dim() == 3
    _, rows, cols = x.shape
    batch_stride, row_stride, col_stride = x.stride()
    return (row_stride == 1 and col_stride == rows
            and batch_stride == rows * cols)

60

61
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds the checkpoint / activation settings used by ``get_quant_method``
    to pick the FP8 quantize method for linear, fused-MoE and attention
    layers.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint weights are
                already stored in FP8.
            activation_scheme: One of ``ACTIVATION_SCHEMES``
                ("static" or "dynamic").
            ignored_layers: Layer names passed to ``is_layer_skipped``;
                matching linear layers are left unquantized.
            weight_block_size: Optional two-element block size for
                block-wise weight quantization.

        Raises:
            ValueError: If the activation scheme is unknown, or if block-wise
                quantization is requested for a non-FP8 checkpoint, with a
                block size that is not 2-D, or with a non-dynamic activation
                scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (80) required."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config files to read; none are needed for FP8."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rename ``ignored_layers`` from HF names to vLLM names in place."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(
                self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        The checkpoint is treated as FP8-serialized when its
        ``quant_method`` contains "fp8".
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``, or None if unsupported.

        Linear layers matched by ``ignored_layers`` get
        ``UnquantizedLinearMethod``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

165
166
167

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale, including block-wise quantized weights
    when ``weight_block_size`` is set (dynamic activation scale only).

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weights are quantized to FP8, and the weight
    scaling factor initialized, after the model weights are loaded.

    Limitations:
    1. Without block quantization (and without Marlin), the per-shard weight
       scales of a fused module are collapsed into a single per-tensor scale,
       because torch._scaled_mm needs a per-tensor weight scale.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_fp8_supported=cutlass_fp8_supported())

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight, weight-scale and input-scale params on layer.

        Args:
            layer: The linear layer being quantized.
            input_size_per_partition: Input size of this TP partition.
            output_partition_sizes: Output sizes of the logical shards fused
                into this layer; their sum is the partition's output size.
            input_size: Full (unpartitioned) input size.
            output_size: Full (unpartitioned) output size.
            params_dtype: Dtype of the original parameters; used for the
                weight when the checkpoint is not FP8-serialized.
            **extra_weight_attrs: ``weight_loader`` is read from here and
                attached to the registered parameters.

        Raises:
            ValueError: If block quantization is enabled and a partitioned
                dimension is not divisible by the corresponding block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                # One scale per logical shard of the fused module.
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounded up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally re-stride ``weight`` on ROCm; shape is unchanged.

        When enabled, the last dim is padded by 256 bytes and then sliced
        back, so the returned view has the same shape but a row stride that
        is 256 bytes larger.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the layout ``apply`` expects.

        Depending on the config, this:
        - normalizes e4m3fn to e4m3fnuz on fnuz platforms;
        - quantizes unquantized weights to FP8;
        - collapses per-shard weight scales into one scale;
        - transposes non-block weights;
        - prepares Marlin or Blackwell DeepGemm layouts.
        """
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # Collapse the per-shard input scales into a single scale.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On B200, DeepGemm only support E8M0 scale, which means we need to
        # requantize the weight and input to the specific scale
        # at the same time. This only applies to block-quantized layers:
        # non-block layers keep ``weight_block_size = None`` (set in
        # create_weights) and would otherwise trip the assert below.
        if self.block_quant and is_blackwell_deep_gemm_used():
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the FP8 linear of ``x``.

        Dispatches to the Marlin kernel, to the block-wise W8A8 op, or to
        ``Fp8LinearOp``, in that order of precedence.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
455
456


457
458
459
460
461
462
463
464
465
466
467
468
469
470
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Pick the fused-MoE kernel variants for this FP8 config.

        Decides whether Marlin, DeepGemm and the CUTLASS block-scaled
        grouped GEMM may be used. It then binds ``fused_experts`` with those
        choices as ``self.fused_experts``.

        Args:
            quant_config: The quantization config.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif (current_platform.is_cuda()
                  and current_platform.is_device_capability(90)):
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
            elif (current_platform.is_cuda()
                  and is_blackwell_deep_gemm_used()):
                logger.info_once("Using DeepGemm SM100 kernels for "
                                 "Fp8MoEMethod.")
                self.allow_deep_gemm = True
            else:
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

        # Check for CutlassBlockScaledGroupedGemm support.
        self.allow_cutlass_block_scaled_grouped_gemm = False
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.is_device_capability(100)):
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
            self.allow_cutlass_block_scaled_grouped_gemm = True
        else:
            logger.warning_once(
                "CutlassBlockScaledGroupedGemm not supported on the current "
                "platform.")

        self.topk_indices_dtype = None
        self.fused_experts = functools.partial(  # type: ignore
            fused_experts,
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
            allow_cutlass_block_scaled_grouped_gemm=(
                self.allow_cutlass_block_scaled_grouped_gemm))

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register expert weights, weight scales and input scales on layer.

        Args:
            layer: The fused-MoE layer being quantized.
            num_experts: Number of experts.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size of this
                TP partition.
            params_dtype: Dtype of the original parameters. It is replaced by
                float8_e4m3fn when the checkpoint is FP8-serialized.
            **extra_weight_attrs: Attributes set on the weights. A
                ``quant_method`` entry is then added before they are set on
                the scale parameters.

        Raises:
            ValueError: If a block-quantized partition size is not divisible
                by the block size, or if a static activation scheme is used
                with a checkpoint that is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        # w13 stacks the two intermediate projections along dim 1.
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounded up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
             value} if self.block_quant else
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded MoE weights into the form the kernels expect.

        Three loading paths are handled:

        * Block-quantized checkpoints (dynamic activations only):
          - normalize to e4m3fnuz on fnuz platforms;
          - re-wrap the tensors as plain ``Parameter`` objects;
          - shuffle for ROCm AITER;
          - pre-align the block scales for DeepGEMM.
        * Non-FP8 checkpoints (e.g. fp16): quantize each expert's w13/w2 to
          FP8 in place, with one scale per expert.
        * Per-tensor FP8 checkpoints:
          - collapse static input scales to their max across experts;
          - normalize to e4m3fnuz when needed;
          - re-quantize both w13 shards to a single per-expert scale.

        Afterwards the layer is prepared for Marlin if that backend is in
        use. Block-quantized layers are also re-quantized to UE8M0 scales
        when Blackwell DeepGEMM is active.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # Activations are dynamic here, so the returned input scales
                # are not needed.
                w13_weight, w13_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            else:
                # Unwrap w13 and w2 the same way (raw tensors via `.data`).
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons. The Blackwell variant
            # is handled after UE8M0 re-quantization at the end of this method.
            if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.local_num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # w13 holds two shards of `shard_size` rows, each with its
                # own scale; re-quantize both onto the shared max scale.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        # Only block-quantized layers carry `*_weight_scale_inv` tensors,
        # which are registered only on the block-quant path. The assert below
        # also requires a block size. Per-tensor FP8 layers must therefore
        # skip the Blackwell UE8M0 re-quantization.
        if self.block_quant and is_blackwell_deep_gemm_used():
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv).contiguous()
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv).contiguous()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Build the FP8 experts kernel used behind an all2all prepare/finalize.

        The activation format reported by ``prepare_finalize`` picks the
        kernel class:

        * ``BatchedExperts`` gets ``BatchedTritonOrDeepGemmExperts``, sized
          by the per-rank token cap.
        * Any other format gets ``TritonOrDeepGemmExperts``.

        Args:
            prepare_finalize: The dispatch/combine object this kernel pairs
                with.
            moe: MoE configuration. Not read by this implementation.

        Returns:
            The constructed experts kernel.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not (self.use_marlin or self.rocm_aiter_moe_enabled), (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        block_shape = self.quant_config.weight_block_size
        uses_batched_format = (prepare_finalize.activation_format ==
                               FusedMoEActivationFormat.BatchedExperts)

        if not uses_batched_format:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, block_shape, False)
            return TritonOrDeepGemmExperts(
                use_fp8_w8a8=True,
                block_shape=block_shape,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        # The batched layout needs a per-rank token cap to size its buffers.
        token_cap = prepare_finalize.max_num_tokens_per_rank()
        assert token_cap is not None
        logger.debug(
            "BatchedTritonOrDeepGemmExperts(%s): "
            "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
            self.__class__.__name__, token_cap, block_shape, False)
        return BatchedTritonOrDeepGemmExperts(
            max_num_tokens=token_cap,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            use_fp8_w8a8=True,
            block_shape=block_shape,
            per_act_token_quant=False,
            allow_deep_gemm=self.allow_deep_gemm,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the matching FP8 MoE backend.

        Routing is computed with ``FusedMoE.select_experts``. Then the first
        matching backend runs:

        1. ROCm AITER, when enabled.
        2. Marlin, which supports only the ``"silu"`` activation.
        3. Otherwise, ``self.fused_experts``.

        Block-quantized layers pass their ``*_weight_scale_inv`` tensors as
        weight scales; other layers pass ``*_weight_scale``.

        Returns:
            The tensor produced by the chosen backend.
        """
        if enable_eplb:
            # All EPLB bookkeeping tensors must be supplied together.
            for eplb_tensor in (expert_load_view, logical_to_physical_map,
                                logical_replica_count):
                assert eplb_tensor is not None
            assert isinstance(layer, FusedMoE)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        def _weight_scales():
            # Block-quantized layers store inverse block scales under
            # `*_weight_scale_inv`; the others use `*_weight_scale`.
            if self.block_quant:
                return layer.w13_weight_scale_inv, layer.w2_weight_scale_inv
            return layer.w13_weight_scale, layer.w2_weight_scale

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            aiter_w1_scale, aiter_w2_scale = _weight_scales()
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=aiter_w1_scale,
                w2_scale=aiter_w2_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)

        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        w1_scale, w2_scale = _weight_scales()
        return self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
998
999
1000
    """

    def __init__(self, quant_config: Fp8Config):
1001
        super().__init__(quant_config)