# fp8.py: FP8 quantization config plus linear, fused-MoE and KV-cache
# quantization methods.
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

# Activation quantization schemes accepted by Fp8Config.
#   "static":  per-layer activation (input) scales are loaded from the
#              checkpoint alongside the weights.
#   "dynamic": no activation scale is stored; the layer's input scale is
#              left as None (presumably computed at runtime by the kernel).
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            weights in FP8. When False, weights are loaded in their original
            dtype and quantized to FP8 after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Layer names checked with ``is_layer_skipped``;
            matching linear layers are left unquantized. Defaults to none.
        weight_block_size: Optional ``[block_n, block_k]`` block shape that
            enables block-wise weight quantization.

    Raises:
        ValueError: If ``activation_scheme`` is not supported, or if
            ``weight_block_size`` is combined with a non-fp8-serialized
            checkpoint, does not have exactly two entries, or is combined
            with a non-dynamic activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            # Block-wise quantization is only wired up for pre-quantized
            # checkpoints with 2-D blocks and dynamic activation scaling.
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method ("fp8")."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes supported by FP8 quantization."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config files to read (none for FP8)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization-config dict.

        Reads ``quant_method`` and ``activation_scheme`` via
        ``get_from_keys`` (no default), and ``ignored_layers`` and
        ``weight_block_size`` via ``get_from_keys_or`` (default ``None``).
        The checkpoint is treated as fp8-serialized when ``quant_method``
        contains the substring "fp8".
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Linear layers get ``Fp8LinearMethod`` (or ``UnquantizedLinearMethod``
        when ``prefix`` matches ``ignored_layers``), fused-MoE layers get
        ``Fp8MoEMethod``, and attention layers get ``Fp8KVCacheMethod``. Any
        other layer type returns ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with
    dynamic activation scaling. The weights are quantized to FP8 and the
    weight scaling factor is initialized after the model weights are loaded.

    Block-wise weight quantization (``weight_block_size`` set in the config)
    is supported for fp8-serialized checkpoints with dynamic activation
    scaling; it runs through ``apply_w8a8_block_fp8_linear`` and skips the
    post-load processing below.

    Limitations:
    1. Outside block-wise mode, only per-tensor quantization is supported
       due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
       On ROCm, e4m3fn checkpoint weights are converted to
       float8_e4m3fnuz after loading.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.block_quant = self.quant_config.weight_block_size is not None
        if self.block_quant:
            # Marlin doesn't support block-wise fp8
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        Always registers ``weight`` with shape
        (output_size_per_partition, input_size_per_partition): float8_e4m3fn
        for fp8-serialized checkpoints, ``params_dtype`` otherwise. For
        fp8-serialized checkpoints it also registers:
          * ``weight_scale`` (one entry per logical output partition) or, in
            block-wise mode, ``weight_scale_inv`` (one entry per
            (block_n, block_k) tile);
          * ``input_scale`` (one entry per logical output partition for the
            static scheme, ``None`` otherwise).

        Raises:
            ValueError: In block-wise mode, if a row-parallel input partition
                is not divisible by ``block_k`` or a column-parallel/merged
                output partition is not divisible by ``block_n``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            # NOTE(review): scales are pre-filled with float32 min,
            # presumably so that entries the loader never writes cannot win
            # a later max() reduction -- confirm.
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # Ceil-divide so partial edge blocks still get a scale.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the chosen kernel expects.

        * Block-wise mode: nothing to do.
        * Non-fp8 checkpoint: quantize ``weight`` to FP8 with a single
          per-tensor scale (expanded to channel-wise for Marlin).
        * fp8 checkpoint: for Marlin, expand the per-shard scales to
          channel-wise. Otherwise, requantize all logical shards with their
          max scale so one per-tensor scale covers the fused weight
          (normalizing to e4m3fnuz first on ROCm).

        The weight is stored transposed afterwards. With Marlin the layer is
        repacked and ``input_scale`` is removed, because Marlin does not
        quantize activations.
        """
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            return
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if current_platform.is_rocm():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # The kernel takes a single activation scale; use the max
                # across the fused shards.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the FP8 linear on ``x``.

        Dispatches to Marlin (weight-only), the block-wise W8A8 kernel, or
        the per-tensor ``apply_fp8_linear`` path, according to the mode
        fixed in ``__init__``.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            # Note: lazy import to avoid triton import error. Kept inside
            # this branch so the per-tensor path never executes it.
            from vllm.model_executor.layers.quantization.utils.fp8_utils import (  # noqa: E501
                apply_w8a8_block_fp8_linear)

            assert self.quant_config.weight_block_size is not None
            return apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with
    dynamic activation scaling. The weights are quantized to FP8 and the
    weight scaling factors are initialized after the model weights are
    loaded.

    Block-wise weight quantization (``weight_block_size`` set in the config)
    is supported for fp8-serialized checkpoints with dynamic activation
    scaling.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and input scales.

        ``w13_weight`` has shape (num_experts, 2 * intermediate_size,
        hidden_size), holding the gate and up projections stacked.
        ``w2_weight`` has shape (num_experts, hidden_size,
        intermediate_size). For fp8-serialized checkpoints both weights use
        float8_e4m3fn.

        Weight scales are either per-tensor (``w13_weight_scale`` with two
        entries per expert, ``w2_weight_scale`` with one) or, in block-wise
        mode, per-tile (``w13_weight_scale_inv`` / ``w2_weight_scale_inv``).
        With the static scheme, ``w13_input_scale`` and ``w2_input_scale``
        are registered; otherwise both are set to ``None``.

        Raises:
            ValueError: If block sizes do not divide ``intermediate_size`` as
                required, or if the static activation scheme is used with a
                checkpoint that is not fp8-serialized.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1 and intermediate_size % block_k != 0):
                # Required by row parallel
                raise ValueError(f"The input_size of down's weight = "
                                 f"{intermediate_size} is not divisible by "
                                 f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n, block_k) tile; ceil-divide so partial
            # edge tiles are covered.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Record the scale granularity (per-tensor or block-wise) so the
        # weight loader handles the scale tensors correctly. This update
        # happens after the weights above were tagged, so only the scale
        # parameters tagged below carry "quant_method".
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK
                      if self.block_quant else
                      FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs.update({"quant_method": scale_kind.value})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded expert weights into what the fused-MoE kernel needs.

        * Block-wise mode: nothing to do.
        * Non-fp8 checkpoint: quantize every expert's w13/w2 weight to FP8
          with one scale per expert and weight (e4m3fnuz buffers on ROCm).
        * fp8 checkpoint: reduce static input scales to a single max value,
          normalize to e4m3fnuz on ROCm, then requantize each expert's two
          w13 shards with their max scale so w13 has one scale per expert.

        Raises:
            ValueError: If the static scheme is configured but an input
                scale is ``None``.
        """
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            return

        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = (torch.float8_e4m3fnuz
                         if current_platform.is_rocm() else torch.float8_e4m3fn)
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.num_experts):
                (w13_weight[expert, :, :],
                 layer.w13_weight_scale[expert]) = ops.scaled_fp8_quant(
                     layer.w13_weight.data[expert, :, :])
                (w2_weight[expert, :, :],
                 layer.w2_weight_scale[expert]) = ops.scaled_fp8_quant(
                     layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # Checkpoint is fp8: the MoE kernels require a single activation
        # scale and a single w13 weight scale per expert.

        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if current_platform.is_rocm():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start +
                                                shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route ``x`` to its top-k experts and run the FP8 fused-MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``.
        ``fused_experts`` is then called with ``inplace=True`` and the
        per-tensor or block-wise weight scales, depending on the mode.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        # All behavior is inherited from BaseKVCacheMethod; this class only
        # binds it to an Fp8Config.
        super().__init__(quant_config)