fp8.py 42.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
bnellnm's avatar
bnellnm committed
5
from typing import TYPE_CHECKING, Any, Callable, Optional
6
7

import torch
8
import torch.nn.functional as F
9
10
11
from torch.nn import Module
from torch.nn.parameter import Parameter

12
import vllm.envs as envs
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
17
18
19
20
from vllm.model_executor.layers.fused_moe import (
    BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
    FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
    TritonOrDeepGemmExperts)
21
22
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
23
from vllm.model_executor.layers.quantization import QuantizationMethods
24
from vllm.model_executor.layers.quantization.base_config import (
25
    QuantizationConfig, QuantizeMethodBase)
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
28
29
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
30
31
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
32
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
33
34
35
36
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
37
38
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
39
                                           PerTensorScaleParameter)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
from vllm.scalar_type import scalar_types
43
from vllm.utils import has_deep_gemm
44

45
46
47
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

48
49
50
51
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

52
53
54
55
56
57

def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

58

59
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint weights are
            already stored in FP8. When False, the linear/MoE methods quantize
            the weights after loading.
        activation_scheme: Activation scaling scheme; must be one of
            ``ACTIVATION_SCHEMES`` ("static" or "dynamic").
        ignored_layers: Names of layers to leave unquantized (matched via
            ``is_layer_skipped``). ``None`` is normalized to ``[]``.
        weight_block_size: Optional ``[block_n, block_k]`` enabling
            block-wise weight quantization.

    Raises:
        ValueError: If ``activation_scheme`` is unsupported, or if
            ``weight_block_size`` is used with a checkpoint that is not
            FP8-serialized, does not have exactly two entries, or is combined
            with a non-"dynamic" activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this config accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (encoded as 80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config filenames to load; FP8 needs none."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` from HF names to vLLM names in place."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(
                self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint's quantization config dict.

        ``quant_method`` and ``activation_scheme`` are required keys;
        ``ignored_layers`` and ``weight_block_size`` default to ``None``.
        The checkpoint counts as FP8-serialized when ``quant_method``
        contains the substring "fp8".
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns ``UnquantizedLinearMethod`` for skipped linear layers,
        ``Fp8LinearMethod`` for other linear layers, ``Fp8MoEMethod`` for
        ``FusedMoE``, ``Fp8KVCacheMethod`` for ``Attention``, and ``None``
        for anything else.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

163
164
165

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale, including block-wise quantized FP8
    checkpoints (``weight_block_size`` set, dynamic activations only).

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling: the weights are quantized to FP8, and the weight
    scaling factor is computed, after the model weights are loaded.

    Limitations (non-block path):
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None
        self.fp8_linear = Fp8LinearOp(
            # Default to using per_token quantization if cutlass is supported
            use_per_token_if_dynamic=cutlass_fp8_supported())

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        The weight is allocated as ``(sum(output_partition_sizes),
        input_size_per_partition)``, in FP8 when the checkpoint is
        FP8-serialized and in ``params_dtype`` otherwise. Scale parameters are
        only registered for FP8-serialized checkpoints:
        ``weight_scale`` (one entry per output partition) for per-tensor
        quantization, ``weight_scale_inv`` (one entry per weight block) for
        block quantization, and ``input_scale`` for the static scheme.

        Raises:
            ValueError: In block-quant mode, if a partitioned dimension is not
                divisible by the corresponding block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                # Pre-filled with float32 min; presumably a sentinel for
                # entries the loader never writes -- confirm.
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounding up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally enlarge the row stride of ``weight`` on ROCm.

        When ``VLLM_ROCM_FP8_PADDING`` is set on ROCm and the row stride in
        bytes is a multiple of 512, the last dim is padded by 256 bytes and
        sliced back. The logical shape is unchanged, but the rows are now
        spaced further apart in memory. Otherwise ``weight`` is returned as-is.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded parameters into the form ``apply`` expects.

        Paths:
        * block-quantized FP8: optionally normalize to fnuz, optionally pad,
          and re-wrap as plain ``Parameter``s (weight is not transposed).
        * non-FP8 checkpoint: quantize with ``ops.scaled_fp8_quant`` and store
          the FP8 weight transposed; ``input_scale`` becomes ``None``.
        * per-tensor FP8 checkpoint: unless Marlin is used, requantize the
          logical shards to a single max scale; then pad and store the weight
          transposed, collapsing a static ``input_scale`` to its max.

        When Marlin is enabled the layer is then prepared for Marlin and
        ``input_scale`` is deleted.
        """
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatches, in order of precedence, to the Marlin kernel, the
        block-wise ``apply_w8a8_block_fp8_linear`` op, or ``Fp8LinearOp``.
        """
        if self.use_marlin:
            # NOTE(review): layer.workspace is not created in this class;
            # presumably set by prepare_fp8_layer_for_marlin -- confirm.
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
432
433


434
435
436
437
438
439
440
441
442
443
444
445
446
447
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Select the MoE kernel configuration for ``quant_config``.

        Decides whether Marlin (weight-only FP8) is used and whether DeepGemm
        may be used. DeepGemm requires ``VLLM_USE_DEEP_GEMM``, an importable
        deep_gemm, block quantization, and a CUDA device with capability >= 90.
        It then binds ``fused_experts`` with the FP8 W8A8 settings.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                # Fixed: the two concatenated literals previously produced a
                # double space ("Not using  DeepGemm").
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif (current_platform.is_cuda()
                  and current_platform.has_device_capability(90)):
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
            else:
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

        self.topk_indices_dtype = None
        self.fused_experts = functools.partial(  # type: ignore
            fused_experts,
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm)

484
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register MoE expert weights, weight scales and input scales.

        Allocates ``w13_weight`` as ``(num_experts,
        2 * intermediate_size_per_partition, hidden_size)`` and ``w2_weight``
        as ``(num_experts, hidden_size, intermediate_size_per_partition)``.
        These are FP8 for FP8-serialized checkpoints and ``params_dtype``
        otherwise.

        Weight scales are per-tensor (``w13_weight_scale`` with 2 entries per
        expert, ``w2_weight_scale`` with 1) or per block
        (``w13_weight_scale_inv`` / ``w2_weight_scale_inv``). Static input
        scales get one entry per expert.

        ``extra_weight_attrs`` is attached to the weights as given. Before it
        is attached to the scales, it is extended with the
        ``quant_method`` (TENSOR or BLOCK). Weight scales only receive the
        attrs for FP8-serialized checkpoints.

        Raises:
            ValueError: If block sizes do not divide the partitioned
                intermediate size, or if the static activation scheme is used
                with a checkpoint that is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        # Record the original dtype before it may be overridden to FP8 below.
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding up; w13 holds
            # the gate and up projections stacked, hence the factor of 2.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
             value} if self.block_quant else
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None
611
612

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded MoE expert weights into the form the kernels use.

        Exactly one of three paths runs:

        * block-quantized checkpoints (``self.block_quant``): optionally
          normalize e4m3fn -> e4m3fnuz, re-wrap weights and inverse scales as
          plain ``Parameter``s, shuffle weights for ROCm AITER, and pre-align
          the DeepGemm weight scales;
        * checkpoints that are not fp8-serialized: quantize every expert's
          ``w13``/``w2`` weights to fp8, one scale per expert;
        * per-tensor fp8 checkpoints: collapse static input scales to their
          maximum, optionally normalize to e4m3fnuz, and requantize the two
          ``w13`` shards of each expert to one shared (max) scale.

        Afterwards, when Marlin is in use, the layer is prepared for Marlin
        and the activation (input) scales are deleted.

        Args:
            layer: MoE layer whose ``w13_*``/``w2_*`` tensors were filled by
                the weight loader. It is modified in place.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # NOTE(review): the returned input scales are not used below;
                # presumably they are None under the dynamic scheme asserted
                # above — confirm.
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            else:
                # Use the raw tensors for both w13 and w2 so the Parameters
                # created below wrap plain tensors, not Parameter subclasses.
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)

            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                # Lazy import to avoid CUDA initialization problems.
                import deep_gemm as dg

                # nn.Module.__setattr__ raises TypeError when a plain tensor
                # is assigned to a name already registered as a parameter, so
                # the aligned scales must be re-wrapped in a Parameter.
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = Parameter(
                        dg.get_col_major_tma_aligned_tensor(
                            layer.w13_weight_scale_inv).contiguous(),
                        requires_grad=False)
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = Parameter(
                        dg.get_col_major_tma_aligned_tensor(
                            layer.w2_weight_scale_inv).contiguous(),
                        requires_grad=False)

        # If the checkpoint is not fp8-serialized (e.g. fp16), quantize in
        # place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(layer.local_num_experts,
                           dtype=torch.float32,
                           device=w13_weight.device),
                requires_grad=False)
            for expert in range(layer.local_num_experts):
                # scaled_fp8_quant returns (quantized tensor, scale).
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)

            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)

            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                # w13 holds two stacked shards of `shard_size` rows each.
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Build the fp8 expert-GEMM implementation for an all2all setup.

        When ``prepare_finalize`` produces ``BatchedExperts``-format
        activations, a ``BatchedTritonOrDeepGemmExperts`` sized by
        ``prepare_finalize.max_num_tokens_per_rank()`` is returned. Otherwise
        a ``TritonOrDeepGemmExperts`` is returned. ``moe`` is accepted for
        interface compatibility and is not used here.

        Raises:
            AssertionError: if Marlin or ROCm AITER is enabled, or if the
                batched path reports no per-rank token limit.
        """
        assert not (self.use_marlin or self.rocm_aiter_moe_enabled), (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        block_shape = self.quant_config.weight_block_size
        is_batched = (prepare_finalize.activation_format ==
                      FusedMoEActivationFormat.BatchedExperts)

        if not is_batched:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, block_shape, False)
            return TritonOrDeepGemmExperts(
                use_fp8_w8a8=True,
                block_shape=block_shape,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
        assert tokens_per_rank is not None
        logger.debug(
            "BatchedTritonOrDeepGemmExperts(%s): "
            "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
            self.__class__.__name__, tokens_per_rank, block_shape, False)
        # world_size / dp_size are not declared on the base
        # prepare/finalize type, hence the type: ignore markers.
        pf_world_size = prepare_finalize.world_size  # type: ignore [attr-defined]
        pf_dp_size = prepare_finalize.dp_size  # type: ignore [attr-defined]
        return BatchedTritonOrDeepGemmExperts(
            max_num_tokens=tokens_per_rank,
            world_size=pf_world_size,
            dp_size=pf_dp_size,
            use_fp8_w8a8=True,
            block_shape=block_shape,
            per_act_token_quant=False,
            allow_deep_gemm=self.allow_deep_gemm,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route ``x`` to its top-k experts and run the fp8 MoE computation.

        Expert selection is delegated to ``FusedMoE.select_experts``. The
        expert computation is then dispatched, in priority order, to the ROCm
        AITER kernel, the Marlin kernel, or ``self.fused_experts``.

        Raises:
            AssertionError: if EPLB is enabled without all of its tensors or
                on a non-``FusedMoE`` layer, or if Marlin is asked for an
                activation other than "silu" or for router weights applied
                on the input.

        Returns:
            The tensor produced by the selected MoE backend.
        """
        if enable_eplb:
            # EPLB needs all three bookkeeping tensors plus a FusedMoE layer.
            for eplb_tensor in (expert_load_view, logical_to_physical_map,
                                logical_replica_count):
                assert eplb_tensor is not None
            assert isinstance(layer, FusedMoE)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        def pick_weight_scales():
            # Block-quantized layers keep their scales in *_weight_scale_inv.
            if self.block_quant:
                return layer.w13_weight_scale_inv, layer.w2_weight_scale_inv
            return layer.w13_weight_scale, layer.w2_weight_scale

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            aiter_w1_scale, aiter_w2_scale = pick_weight_scales()
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=aiter_w1_scale,
                w2_scale=aiter_w2_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)

        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert not apply_router_weight_on_input, (
                "Apply router weight on input not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        triton_w1_scale, triton_w2_scale = pick_weight_scales()
        return self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=triton_w1_scale,
            w2_scale=triton_w2_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
931
932
933
    """

    def __init__(self, quant_config: Fp8Config):
934
        super().__init__(quant_config)