fp8.py 29.3 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
import vllm.envs as envs
8
from vllm import _custom_ops as ops
9
from vllm.distributed import get_tensor_model_parallel_world_size
10
from vllm.logger import init_logger
11
12
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
13
14
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
15
from vllm.model_executor.layers.quantization.base_config import (
16
    QuantizationConfig, QuantizeMethodBase)
17
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
18
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
19
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
20
21
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
22
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
23
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
24
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
25
    requantize_with_max_scale)
26
27
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
28
                                           PerTensorScaleParameter)
29
from vllm.model_executor.utils import set_weight_attrs
30
from vllm.platforms import current_platform
31
from vllm.utils import print_warning_once
32

33
34
35
36
# Values accepted for Fp8Config's ``activation_scheme`` argument; any other
# value is rejected in Fp8Config.__init__.
ACTIVATION_SCHEMES = [
    "static",
    "dynamic",
]

logger = init_logger(__name__)

37

38
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds whether the checkpoint is already FP8-serialized, the activation
    scaling scheme, the layers to leave unquantized, and the optional
    block size used for block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Args:
            is_checkpoint_fp8_serialized: True when the checkpoint weights
                are already stored in FP8.
            activation_scheme: one of ``ACTIVATION_SCHEMES``.
            ignored_layers: names checked by ``is_layer_skipped`` against a
                linear layer's prefix; matching layers stay unquantized.
                ``None`` means no layers are ignored.
            weight_block_size: optional ``[block_n, block_k]`` enabling
                block-wise weight quantization.

        Raises:
            ValueError: for an unknown activation scheme, or for a block
                size that is used with a non-FP8 checkpoint, does not have
                exactly two entries, or is combined with a non-dynamic
                activation scheme.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._validate_weight_block_size(weight_block_size,
                                             is_checkpoint_fp8_serialized,
                                             activation_scheme)
        self.weight_block_size = weight_block_size

    @staticmethod
    def _validate_weight_block_size(weight_block_size: List[int],
                                    is_serialized: bool,
                                    activation_scheme: str) -> None:
        """Raise ValueError if block-wise quantization cannot be enabled."""
        if not is_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now.")
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions")
        if activation_scheme != "dynamic":
            raise ValueError("The block-wise quantization only supports "
                             "dynamic activation scheme for now, but got "
                             f"{activation_scheme} activation scheme.")

    @classmethod
    def get_name(cls) -> str:
        """Identifier of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Activation dtypes accepted by this method (bf16 and fp16)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Lowest device capability supported (80).

        Devices below capability 89 use the weight-only Marlin path in
        Fp8LinearMethod (except on ROCm, where Marlin is disabled).
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """No additional config files are needed; returns an empty list."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint quantization-config dict.

        ``quant_method`` and ``activation_scheme`` are looked up without a
        fallback; ``ignored_layers`` and ``weight_block_size`` default to
        ``None`` when absent. The checkpoint counts as FP8-serialized when
        ``"fp8"`` occurs in ``quant_method``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=cls.get_from_keys(config,
                                                ["activation_scheme"]),
            ignored_layers=cls.get_from_keys_or(config, ["ignored_layers"],
                                                None),
            weight_block_size=cls.get_from_keys_or(config,
                                                   ["weight_block_size"],
                                                   None),
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get Fp8LinearMethod unless ``prefix`` is listed in
        ``ignored_layers``, in which case they stay unquantized. Fused MoE
        layers get Fp8MoEMethod and attention layers get Fp8KVCacheMethod.
        Any other layer type gets ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None
114
115
116
117


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Also supports FP8 checkpoints quantized block-wise (when the config sets
    ``weight_block_size``); these require dynamic activation scaling and
    never use the Marlin fallback.

    Limitations:
    1. For non-block checkpoints, the w8a8 path requantizes all logical
       shards to a single per-tensor weight scale because of torch._scaled_mm
       support; the Marlin (w8a16) path expands the scales per channel.
    2. Checkpoints must use the float8_e4m3fn data type due to the
       limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856);
       on ROCm, weights are converted to float8_e4m3fnuz after loading.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.block_quant = self.quant_config.weight_block_size is not None
        if self.block_quant:
            # Marlin doesn't support block-wise fp8
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and (for FP8 checkpoints) scale parameters.

        Args:
            layer: module that receives the parameters.
            input_size_per_partition: input features held by this rank.
            output_partition_sizes: output width of each logical shard fused
                into this layer; their sum is this rank's output width.
            input_size: input size before tensor-parallel partitioning.
            output_size: output size before tensor-parallel partitioning.
            params_dtype: weight dtype for non-FP8 checkpoints; also kept
                as ``layer.orig_dtype``.
            **extra_weight_attrs: only ``weight_loader`` is read here.

        Raises:
            ValueError: with block-wise quantization, when a partitioned
                input/output size is not divisible by the block size.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            # NOTE(review): scales are pre-filled with float32 min, presumably
            # as a "not yet loaded" sentinel — confirm against the loaders.
            if not self.block_quant:
                # One scale per logical shard.
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounded up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the FP8 kernels expect.

        Non-FP8 checkpoints are quantized here with a dynamically computed
        scale. FP8 checkpoints either get per-channel scales (Marlin) or are
        requantized to one shared scale (w8a8), with ROCm conversion to
        e4m3fnuz first. In both cases the weight is stored transposed.
        Block-quantized layers are left exactly as loaded.
        """
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            return

        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if current_platform.is_rocm():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # Fused shards share one activation scale: take the max.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (plus optional ``bias``).

        Dispatches to the weight-only Marlin kernel, the block-wise w8a8
        kernel, or the per-tensor w8a8 kernel, depending on how this method
        was configured.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            # Note: lazy import to avoid triton import error. It lives inside
            # this branch so the per-tensor path never pulls in fp8_utils and
            # does not re-run the import statement on every forward call.
            from vllm.model_executor.layers.quantization.utils.fp8_utils import (  # noqa: E501
                apply_w8a8_block_fp8_linear)
            assert self.quant_config.weight_block_size is not None
            return apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)
360
361


362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
class Fp8MoEMethod(FusedMoEMethodBase):
    """Fused-MoE quantization method for FP8.

    Handles three kinds of checkpoints:

    * FP8-serialized with per-tensor weight scales and either dynamic or
      static activation scales;
    * FP8-serialized with block-wise weight scales (``weight_block_size``),
      dynamic activations only;
    * FP16/BF16 checkpoints, which are quantized to FP8 per expert after
      the weights are loaded (dynamic activations only).

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        # Block-wise mode is enabled purely by the presence of a block size.
        self.block_quant = quant_config.weight_block_size is not None

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and input scales.

        Args:
            layer: the fused MoE module that receives the parameters.
            num_experts: number of experts stored on this layer.
            hidden_size: model hidden dimension.
            intermediate_size: per-expert intermediate dimension; the fused
                w13 weight stacks two of these.
            params_dtype: weight dtype for non-FP8 checkpoints.
            **extra_weight_attrs: attached to both weights via
                ``set_weight_attrs``. For FP8-serialized checkpoints the
                scale parameters receive them too, plus a ``quant_method``
                tag (BLOCK or TENSOR).

        Raises:
            ValueError: when block-wise sizes are misaligned, or when a
                static activation scheme is used with a checkpoint that is
                not FP8-serialized.
        """
        serialized = self.quant_config.is_checkpoint_fp8_serialized
        if serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            block_size = self.quant_config.weight_block_size
            assert block_size is not None
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = block_size[0], block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if tp_size > 1 and intermediate_size % block_k != 0:
                raise ValueError(f"The input_size of down's weight = "
                                 f"{intermediate_size} is not divisible by "
                                 f"weight quantization block_k = {block_k}.")

        # WEIGHTS: w13 holds the stacked w1/w3 projections, w2 the down one.
        weight_specs = (
            ("w13_weight", (num_experts, 2 * intermediate_size, hidden_size)),
            ("w2_weight", (num_experts, hidden_size, intermediate_size)),
        )
        for param_name, shape in weight_specs:
            expert_weight = torch.nn.Parameter(torch.empty(
                shape, dtype=params_dtype),
                                               requires_grad=False)
            layer.register_parameter(param_name, expert_weight)
            set_weight_attrs(expert_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            w13_scale_name = "w13_weight_scale_inv"
            w2_scale_name = "w2_weight_scale_inv"
            w13_scale_shape = (
                num_experts,
                2 * ((intermediate_size + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
            )
            w2_scale_shape = (
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size + block_k - 1) // block_k,
            )
            scale_kind = FusedMoeWeightScaleSupported.BLOCK
        else:
            w13_scale_name = "w13_weight_scale"
            w2_scale_name = "w2_weight_scale"
            # Two scales for w13 (w1 and w3); they are merged into one per
            # expert in process_weights_after_loading.
            w13_scale_shape = (num_experts, 2)
            w2_scale_shape = (num_experts, )
            scale_kind = FusedMoeWeightScaleSupported.TENSOR

        w13_weight_scale = torch.nn.Parameter(torch.ones(w13_scale_shape,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(torch.ones(w2_scale_shape,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter(w13_scale_name, w13_weight_scale)
        layer.register_parameter(w2_scale_name, w2_weight_scale)
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"

        # Scale parameters are tagged with how they are quantized
        # (per tensor / block) so the loader can place them correctly.
        scale_attrs = {**extra_weight_attrs, "quant_method": scale_kind.value}

        # Only FP8 checkpoints ship scales; FP16 checkpoints get them
        # computed in process_weights_after_loading().
        if serialized:
            set_weight_attrs(w13_weight_scale, scale_attrs)
            set_weight_attrs(w2_weight_scale, scale_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")
        for param_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter(param_name, input_scale)
            set_weight_attrs(input_scale, scale_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring loaded expert weights into the form fused_experts expects.

        Block-quantized layers are used exactly as loaded. FP16/BF16
        checkpoints are quantized here. FP8 checkpoints get static input
        scales collapsed to one value, are converted to e4m3fnuz on ROCm,
        and have w13 requantized to a single weight scale per expert.
        """
        if self.block_quant:
            return

        if not self.quant_config.is_checkpoint_fp8_serialized:
            self._quantize_unserialized_weights(layer)
            return

        if self.quant_config.activation_scheme == "static":
            self._collapse_static_input_scales(layer)
        if current_platform.is_rocm():
            self._convert_to_e4m3fnuz(layer)
        self._requantize_w13_with_shared_scale(layer)

    @staticmethod
    def _quantize_unserialized_weights(layer: Module) -> None:
        """Quantize FP16/BF16 expert weights to FP8, one scale per expert."""
        if current_platform.is_rocm():
            fp8_dtype = torch.float8_e4m3fnuz
        else:
            fp8_dtype = torch.float8_e4m3fn
        w13_fp8 = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_fp8 = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # The merged w13 tensor is quantized as a whole, so the two scales
        # allocated per expert in create_weights are replaced by one.
        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
            layer.num_experts, dtype=torch.float32, device=w13_fp8.device),
                                                    requires_grad=False)
        for idx in range(layer.num_experts):
            q_w13, s_w13 = ops.scaled_fp8_quant(
                layer.w13_weight.data[idx, :, :])
            w13_fp8[idx, :, :] = q_w13
            layer.w13_weight_scale[idx] = s_w13

            q_w2, s_w2 = ops.scaled_fp8_quant(layer.w2_weight.data[idx, :, :])
            w2_fp8[idx, :, :] = q_w2
            layer.w2_weight_scale[idx] = s_w2

        layer.w13_weight = torch.nn.Parameter(w13_fp8, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_fp8, requires_grad=False)

    @staticmethod
    def _collapse_static_input_scales(layer: Module) -> None:
        """Replace per-expert static input scales by their maximum.

        The fp8 MoE kernels take a single activation scale, so the largest
        one is used; a warning is printed when the experts disagree.
        """
        w13_in = layer.w13_input_scale
        w2_in = layer.w2_input_scale
        if w13_in is None or w2_in is None:
            raise ValueError(
                "QuantConfig has static quantization, but found "
                "activation scales are None.")
        if not (all_close_1d(w13_in) and all_close_1d(w2_in)):
            print_warning_once(
                "Found input_scales that are not equal for "
                "fp8 MoE layer. Using the maximum across experts "
                "for each layer. ")
        layer.w13_input_scale = torch.nn.Parameter(w13_in.max(),
                                                   requires_grad=False)
        layer.w2_input_scale = torch.nn.Parameter(w2_in.max(),
                                                  requires_grad=False)

    @staticmethod
    def _convert_to_e4m3fnuz(layer: Module) -> None:
        """Re-encode w13/w2 weights and scales as float8_e4m3fnuz (ROCm)."""
        new_w13, new_w13_scale, new_w13_input = normalize_e4m3fn_to_e4m3fnuz(
            layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale)
        new_w2, new_w2_scale, new_w2_input = normalize_e4m3fn_to_e4m3fnuz(
            layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale)

        layer.w13_weight = torch.nn.Parameter(new_w13, requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(new_w13_scale,
                                                    requires_grad=False)
        if new_w13_input is not None:
            layer.w13_input_scale = torch.nn.Parameter(new_w13_input,
                                                       requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(new_w2, requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(new_w2_scale,
                                                   requires_grad=False)
        if new_w2_input is not None:
            layer.w2_input_scale = torch.nn.Parameter(new_w2_input,
                                                      requires_grad=False)

    @staticmethod
    def _requantize_w13_with_shared_scale(layer: Module) -> None:
        """Give every expert a single w13 weight scale.

        Each of the two w13 shards (w1, w3) is dequantized with its own
        scale and requantized in place with the larger of the two.
        """
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        shared_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            expert_weight = layer.w13_weight[expert_id]
            expert_scales = layer.w13_weight_scale[expert_id]
            for shard_id in range(2):
                rows = slice(shard_id * shard_size,
                             (shard_id + 1) * shard_size)
                dequantized = per_tensor_dequantize(expert_weight[rows, :],
                                                    expert_scales[shard_id])
                expert_weight[rows, :], _ = ops.scaled_fp8_quant(
                    dequantized, shared_scales[expert_id])

        layer.w13_weight_scale = torch.nn.Parameter(shared_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route ``x`` to its top-k experts and run the FP8 fused MoE.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        expert computation runs via ``fused_experts`` with ``inplace=True``
        and the block-wise or per-tensor weight scales, as configured.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        if self.block_quant:
            w1_scale = layer.w13_weight_scale_inv
            w2_scale = layer.w2_weight_scale_inv
        else:
            w1_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )
643
644


645
646
647
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantize method selected by ``Fp8Config`` for attention.

    Supports loading kv-cache scaling factors from FP8 checkpoints. This
    subclass adds no behavior of its own; it only forwards the FP8 config
    to ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: Fp8Config):
        super(Fp8KVCacheMethod, self).__init__(quant_config)