fp8.py 32.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Any, Callable, Dict, List, Optional
4
5

import torch
6
import torch.nn.functional as F
7
8
9
from torch.nn import Module
from torch.nn.parameter import Parameter

10
import vllm.envs as envs
11
from vllm import _custom_ops as ops
12
from vllm.distributed import get_tensor_model_parallel_world_size
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
16
17
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
18
from vllm.model_executor.layers.quantization.base_config import (
19
    QuantizationConfig, QuantizeMethodBase)
20
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
21
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
22
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
23
24
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
25
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
26
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
27
    cutlass_block_fp8_supported, cutlass_fp8_supported,
28
29
    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize, requantize_with_max_scale)
30
31
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
32
                                           PerTensorScaleParameter)
33
from vllm.model_executor.utils import set_weight_attrs
34
from vllm.platforms import current_platform
35

logger = init_logger(__name__)

# Activation quantization schemes accepted by ``Fp8Config``: "static" uses
# scales loaded from the checkpoint, "dynamic" computes them at runtime.
ACTIVATION_SCHEMES = ["static", "dynamic"]

class Fp8Config(QuantizationConfig):
    """Quantization config for FP8 weights (and optionally activations).

    Attributes:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights; otherwise weights are quantized after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer names checked with ``is_layer_skipped``;
            matching linear layers stay unquantized.
        weight_block_size: ``[block_n, block_k]`` for block-wise weight
            quantization, or ``None`` for per-tensor weight scales.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._check_weight_block_size(weight_block_size,
                                          is_checkpoint_fp8_serialized,
                                          activation_scheme)
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_weight_block_size(weight_block_size: List[int],
                                 is_serialized: bool,
                                 activation_scheme: str) -> None:
        """Raise ``ValueError`` if block-wise quantization is misconfigured.

        Block-wise quantization requires an fp8-serialized checkpoint, a
        two-dimensional block size and the dynamic activation scheme.
        """
        if not is_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now.")
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions")
        if activation_scheme != "dynamic":
            raise ValueError("The block-wise quantization only supports "
                             "dynamic activation scheme for now, but got "
                             f"{activation_scheme} activation scheme.")

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability code required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """No extra config files are needed; the HF config is enough."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        The checkpoint counts as fp8-serialized when its ``quant_method``
        contains the substring ``"fp8"``.
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                          None)
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``, or ``None`` if unsupported.

        Linear layers listed in ``ignored_layers`` fall back to the
        unquantized method.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(prefix, self.ignored_layers)
            return UnquantizedLinearMethod() if skipped else Fp8LinearMethod(
                self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Kernel selection (fixed in ``__init__``):
    - Block-wise FP8 (``quant_config.weight_block_size`` set): block scales
      are registered as ``weight_scale_inv`` and only the dynamic activation
      scheme is allowed.
    - Marlin weight-only FP8 (w8a16): used when the device does not report
      capability 89 or ``VLLM_TEST_FORCE_FP8_MARLIN`` is set; never used on
      ROCm or with block-wise quantization.
    - Otherwise w8a8 through ``apply_fp8_linear``.

    Limitations:
    1. In the non-block, non-Marlin path the logical shards of a fused
       weight are requantized to a single per-tensor scale, because
       torch._scaled_mm needs per-tensor scales.
    2. Checkpoint weights are expected as float8_e4m3fn due to the
       limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856);
       on ROCm they are converted to float8_e4m3fnuz after loading.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.block_quant = self.quant_config.weight_block_size is not None
        if self.block_quant:
            # Marlin doesn't support block-wise fp8
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        ``weight`` has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``. Its dtype
        is float8_e4m3fn for fp8-serialized checkpoints and ``params_dtype``
        otherwise (quantized later in ``process_weights_after_loading``).
        For fp8-serialized checkpoints this also registers either
        ``weight_scale`` (one entry per logical output partition) or, with
        block-wise quantization, ``weight_scale_inv`` (one entry per
        ``block_n x block_k`` tile, rounded up), plus ``input_scale`` (a
        parameter for the static scheme, ``None`` for the dynamic scheme).

        Raises:
            ValueError: if block-wise quantization is enabled and a
                tensor-parallel or merged partition is not divisible by the
                corresponding block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            # Scales are pre-filled with float32 min, presumably as an
            # "unloaded" sentinel -- TODO confirm against the loaders.
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally give ``weight`` a wider row stride on ROCm.

        Only when ``VLLM_ROCM_FP8_PADDING`` is set, the platform is ROCm, the
        last dim is contiguous and the row stride in bytes is a multiple of
        512. Each row is padded by 256 bytes and the padding is sliced off
        again, so shape and values are unchanged and only the stride grows.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded parameters into the layout ``apply`` expects.

        - Block-wise: convert to float8_e4m3fnuz on ROCm, apply the optional
          ROCm padding, and replace parameter subclasses with plain
          ``Parameter`` objects.
        - FP16/BF16 checkpoint: quantize ``weight`` to FP8 with a single
          scale (expanded per channel for Marlin) and store it transposed.
        - FP8 checkpoint: expand per-shard scales per channel (Marlin), or
          requantize all logical shards with the max scale (w8a8), then
          store the weight transposed.
        With Marlin, the layer is finally repacked and ``input_scale`` is
        removed.
        """
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_rocm():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self.add_padding_to_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)
            return

        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if current_platform.is_rocm():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self.add_padding_to_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # A single activation scale: take the max across shards.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the FP8 linear output for ``x`` with the selected kernel."""
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            # Lazy import to avoid a triton import error. Only the block-wise
            # path needs fp8_utils, so the import lives inside this branch
            # and the per-tensor path below never triggers it.
            from vllm.model_executor.layers.quantization.utils import (
                fp8_utils)
            assert self.quant_config.weight_block_size is not None
            return fp8_utils.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            # Default to using per_token quantization if cutlass is supported
            use_per_token_if_dynamic=self.cutlass_fp8_supported)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Store the config and note whether block-wise scales are used."""
        config = quant_config
        self.quant_config = config
        # A configured weight block size switches on block-wise quantization.
        self.block_quant = config.weight_block_size is not None

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register expert weights and their scales on ``layer``.

        Weights: ``w13_weight`` of shape
        ``(num_experts, 2 * intermediate_size_per_partition, hidden_size)``
        and ``w2_weight`` of shape
        ``(num_experts, hidden_size, intermediate_size_per_partition)``, in
        float8_e4m3fn for fp8-serialized checkpoints, ``params_dtype``
        otherwise.

        Weight scales: per-tensor ``w13_weight_scale`` (two per expert, for
        w1 and w3) and ``w2_weight_scale`` (one per expert), or, with
        block-wise quantization, per-block ``w13_weight_scale_inv`` /
        ``w2_weight_scale_inv``. Loader attributes are attached to the
        scales only for fp8-serialized checkpoints.

        Input scales: ``w13_input_scale`` / ``w2_input_scale`` (one per
        expert) for the static scheme, otherwise both set to ``None``.

        Raises:
            ValueError: if the intermediate size is not divisible by the
                quantization block size, or if the static activation scheme
                is used with a checkpoint that is not serialized fp8.
        """
        config = self.quant_config
        fp8_serialized = config.is_checkpoint_fp8_serialized
        if fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert config.weight_block_size is not None
            tp_size = get_tensor_model_parallel_world_size()
            block_n = config.weight_block_size[0]
            block_k = config.weight_block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS (w13 = fused w1/w3, w2 = down projection)
        weight_shapes = (
            ("w13_weight",
             (num_experts, 2 * intermediate_size_per_partition, hidden_size)),
            ("w2_weight",
             (num_experts, hidden_size, intermediate_size_per_partition)),
        )
        for weight_name, shape in weight_shapes:
            expert_weight = torch.nn.Parameter(torch.empty(
                *shape, dtype=params_dtype),
                                               requires_grad=False)
            layer.register_parameter(weight_name, expert_weight)
            set_weight_attrs(expert_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:

            def _cdiv(num: int, den: int) -> int:
                return (num + den - 1) // den

            scale_name = "weight_scale_inv"
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * _cdiv(intermediate_size_per_partition, block_n),
                    _cdiv(hidden_size, block_k),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    _cdiv(hidden_size, block_n),
                    _cdiv(intermediate_size_per_partition, block_k),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
        else:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            scale_name = "weight_scale"
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
        layer.register_parameter(f"w13_{scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{scale_name}", w2_weight_scale)
        if self.block_quant:
            assert config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK
                      if self.block_quant else
                      FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs["quant_method"] = scale_kind.value

        # Only an fp8 checkpoint carries scales to load; for an fp16
        # checkpoint they are produced in process_weights_after_loading().
        if fp8_serialized:
            for weight_scale in (w13_weight_scale, w2_weight_scale):
                set_weight_attrs(weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_param_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter(scale_param_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded MoE parameters into the form the fused kernels use.

        Dispatches on how the checkpoint was quantized:

        * block-wise FP8 -> ``_process_block_quant_weights``
        * not fp8-serialized (FP16/BF16) -> ``_quantize_fp16_weights``
        * per-tensor FP8 -> ``_process_fp8_serialized_weights``

        Args:
            layer: The MoE layer whose parameters were registered by
                ``create_weights`` and filled by the weight loader.
        """
        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            self._process_block_quant_weights(layer)
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            self._quantize_fp16_weights(layer)
        else:
            self._process_fp8_serialized_weights(layer)

    def _process_block_quant_weights(self, layer: Module) -> None:
        """Normalize block-quantized weights on ROCm; unwrap parameters."""
        assert self.quant_config.activation_scheme == "dynamic"
        if current_platform.is_rocm():
            w13_weight, w13_weight_scale_inv, _ = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale_inv,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale_inv, _ = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale_inv,
                    layer.w2_input_scale)
        else:
            # Unwrap w13 and w2 the same way (previously w2 skipped `.data`).
            w13_weight = layer.w13_weight.data
            w13_weight_scale_inv = layer.w13_weight_scale_inv.data
            w2_weight = layer.w2_weight.data
            w2_weight_scale_inv = layer.w2_weight_scale_inv.data

        # torch.compile() cannot use Parameter subclasses.
        layer.w13_weight = Parameter(w13_weight, requires_grad=False)
        layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                               requires_grad=False)
        layer.w2_weight = Parameter(w2_weight, requires_grad=False)
        layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                              requires_grad=False)

    def _quantize_fp16_weights(self, layer: Module) -> None:
        """Quantize each expert's w13/w2 weights to FP8 with one scale each."""
        # If rocm, use float8_e4m3fnuz as dtype
        fp8_dtype = (torch.float8_e4m3fnuz
                     if current_platform.is_rocm() else torch.float8_e4m3fn)
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # Re-initialize w13_scale because we directly quantize
        # merged w13 weights and generate a single scaling factor.
        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
            layer.local_num_experts,
            dtype=torch.float32,
            device=w13_weight.device),
                                                    requires_grad=False)
        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[
                expert] = ops.scaled_fp8_quant(
                    layer.w13_weight.data[expert, :, :])
            w2_weight[expert, :, :], layer.w2_weight_scale[
                expert] = ops.scaled_fp8_quant(
                    layer.w2_weight.data[expert, :, :])
        layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

    def _process_fp8_serialized_weights(self, layer: Module) -> None:
        """Collapse the scales of a per-tensor FP8 checkpoint.

        The MoE kernels need a single activation scale for each of w13/w2
        and a single w13 weight scale per expert. Activation scales are
        reduced with ``max``. Each expert's two w13 shards (w1, w3) are
        dequantized and requantized with the larger of their two scales.

        Raises:
            ValueError: if the static scheme is configured but activation
                scales are missing.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if current_platform.is_rocm():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                # Write the requantized shard back into the weight in place.
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Run the FP8 (W8A8) fused-MoE forward pass for ``layer``.

        Tokens are first routed with ``FusedMoE.select_experts``. The
        selected experts are then evaluated by ``fused_experts`` using the
        layer's FP8 weights, weight scales and activation (input) scales.

        Args:
            layer: The MoE layer holding ``w13_weight`` / ``w2_weight`` and
                their scale parameters (``*_weight_scale_inv`` when block
                quantization is enabled, ``*_weight_scale`` otherwise), plus
                ``w13_input_scale`` / ``w2_input_scale``.
            x: Input hidden states passed to both routing and the experts.
            router_logits: Router outputs used to select experts.
            top_k: Number of experts selected per token.
            renormalize: Forwarded to ``select_experts``.
            use_grouped_topk, topk_group, num_expert_group,
            custom_routing_function, scoring_func, e_score_correction_bias:
                Routing options forwarded unchanged to ``select_experts``.
            global_num_experts, expert_map: Forwarded unchanged to
                ``fused_experts``; they are not used for routing here.
            activation: Activation name forwarded to ``fused_experts``.

        Returns:
            The tensor returned by ``fused_experts``.
        """
        # Imported inside the method rather than at module level
        # (presumably to avoid an import cycle / import-time cost -- confirm).
        from vllm.model_executor.layers.fused_moe import fused_experts

        # Step 1: pick the top-k experts (and their weights) for every token.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        # Step 2: run the selected experts with FP8 W8A8 weights.
        # inplace=True is forwarded to fused_experts (presumably allows the
        # kernel to reuse x's storage for the output -- confirm in
        # fused_experts before relying on x afterwards).
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            # Block-quantized checkpoints store inverse scales under
            # *_weight_scale_inv; per-tensor/per-channel ones use
            # *_weight_scale.
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            # Activation scales; these may be None (the weight-processing
            # code guards on `is not None`), e.g. for dynamic activation
            # quantization -- confirm.
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )
714
715


716
717
718
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
719
720
721
    """

    def __init__(self, quant_config: Fp8Config):
722
        super().__init__(quant_config)