"docs/vscode:/vscode.git/clone" did not exist on "d394787e5268903a705850413e494ebf2ddcefb5"
fp8.py 29.3 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
import vllm.envs as envs
8
from vllm import _custom_ops as ops
9
from vllm.distributed import get_tensor_model_parallel_world_size
10
from vllm.logger import init_logger
11
12
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
13
14
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
15
from vllm.model_executor.layers.quantization.base_config import (
16
    QuantizationConfig, QuantizeMethodBase)
17
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
18
19
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_w8a8_block_fp8_linear)
20
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
21
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
22
23
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
24
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
25
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
26
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
27
    requantize_with_max_scale)
28
29
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
30
                                           PerTensorScaleParameter)
31
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
33
from vllm.utils import print_warning_once
34

35
36
37
38
# Activation quantization schemes that Fp8Config accepts; any other value
# is rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger, named after this module.
logger = init_logger(__name__)

39

40
class Fp8Config(QuantizationConfig):
    """Quantization settings for FP8 layers.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. A warning is logged when this is set.
        activation_scheme: One of ``ACTIVATION_SCHEMES``.
        ignored_layers: Names passed to ``is_layer_skipped``; linear layers
            that match are left unquantized. ``None`` becomes ``[]``.
        weight_block_size: Optional two-element block shape. When given,
            block-wise weight quantization is enabled.

    Raises:
        ValueError: If ``activation_scheme`` is not recognised, or if
            ``weight_block_size`` is combined with a checkpoint that is not
            FP8-serialized, has a length other than 2, or is used with a
            non-dynamic activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

        self.ignored_layers = ignored_layers if ignored_layers else []

        if weight_block_size is not None:
            self._check_block_quant_args(is_checkpoint_fp8_serialized,
                                         activation_scheme,
                                         weight_block_size)
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_block_quant_args(
        is_checkpoint_fp8_serialized: bool,
        activation_scheme: str,
        weight_block_size: List[int],
    ) -> None:
        """Reject settings that block-wise quantization does not accept."""
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now.")
        if len(weight_block_size) != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {len(weight_block_size)} dimensions")
        if activation_scheme != "dynamic":
            raise ValueError("The block-wise quantization only supports "
                             "dynamic activation scheme for now, but got "
                             f"{activation_scheme} activation scheme.")

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dictionary.

        ``quant_method`` and ``activation_scheme`` are looked up with
        ``get_from_keys``; ``ignored_layers`` and ``weight_block_size`` are
        optional and default to ``None``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # The checkpoint counts as FP8-serialized when the method name
        # contains "fp8".
        serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                          None)
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns ``None`` when the layer is not a linear, fused-MoE or
        attention layer.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(prefix, self.ignored_layers)
            return UnquantizedLinearMethod() if skipped else Fp8LinearMethod(
                self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None


class Fp8LinearMethod(LinearMethodBase):
    """Linear-layer quantization method for FP8.

    Handles FP8 checkpoints that carry static weight scales, with either
    dynamic or static activation scales. Also handles FP16/BF16
    checkpoints: their weights are quantized once loading has finished and
    activations are scaled dynamically.

    When the config sets ``weight_block_size`` the block-wise path is used
    and Marlin is disabled.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        marlin_wanted = (not current_platform.has_device_capability(89)
                         or envs.VLLM_TEST_FORCE_FP8_MARLIN)

        on_rocm = current_platform.is_rocm()
        self.block_quant = self.quant_config.weight_block_size is not None
        # Marlin is disabled on rocm and doesn't support block-wise fp8.
        if on_rocm or self.block_quant:
            marlin_wanted = False
        self.use_marlin = marlin_wanted

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        Always registers ``weight``. For FP8-serialized checkpoints it also
        registers ``weight_scale`` (or ``weight_scale_inv`` for block-wise
        quantization) and ``input_scale``; the latter is ``None`` unless
        the activation scheme is static.

        Raises:
            ValueError: In block-wise mode, if a sharded input size is not
                divisible by ``block_k`` or an output partition size is
                not divisible by ``block_n``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            block_shape = self.quant_config.weight_block_size
            assert block_shape is not None
            block_n = block_shape[0]
            block_k = block_shape[1]

            # Required by row parallel
            input_is_sharded = (tp_size > 1 and input_size //
                                input_size_per_partition == tp_size)
            if input_is_sharded and input_size_per_partition % block_k != 0:
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

            # Required by column parallel or enabling merged weights
            output_is_sharded = (tp_size > 1 and output_size //
                                 output_size_per_partition == tp_size)
            if output_is_sharded or len(output_partition_sizes) > 1:
                # First partition whose size is not a multiple of block_n.
                output_partition_size = next(
                    (width for width in output_partition_sizes
                     if width % block_n != 0), None)
                if output_partition_size is not None:
                    raise ValueError(
                        f"Weight output_partition_size = "
                        f"{output_partition_size} is not divisible by "
                        f"weight quantization block_n = {block_n}.")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        serialized = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if serialized else params_dtype
        weight_storage = torch.empty(output_size_per_partition,
                                     input_size_per_partition,
                                     dtype=weight_dtype)
        layer.register_parameter(
            "weight",
            ModelWeightParameter(data=weight_storage,
                                 input_dim=1,
                                 output_dim=0,
                                 weight_loader=weight_loader))

        # Scales are only loaded from fp8-serialized checkpoints; otherwise
        # they are produced in process_weights_after_loading.
        if not serialized:
            return

        fill_value = torch.finfo(torch.float32).min

        # WEIGHT SCALE
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            weight_scale = BlockQuantScaleParameter(
                data=torch.empty(
                    (output_size_per_partition + block_n - 1) // block_n,
                    (input_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            weight_scale[:] = fill_value
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", weight_scale)
        else:
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes),
                                 dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = fill_value
            layer.register_parameter("weight_scale", weight_scale)

        # INPUT ACTIVATION SCALE
        if self.quant_config.activation_scheme != "static":
            layer.register_parameter("input_scale", None)
            return

        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes),
                             dtype=torch.float32),
            weight_loader=weight_loader,
        )
        input_scale[:] = fill_value
        layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights and scales to the form ``apply`` expects.

        Block-wise quantized layers are left untouched. Otherwise the
        weight is quantized (non-FP8 checkpoint) or its scales are
        consolidated (FP8 checkpoint), and the layer is prepared for Marlin
        when ``use_marlin`` is set.
        """
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            return

        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        if self.quant_config.is_checkpoint_fp8_serialized:
            self._finalize_serialized_weights(layer)
        else:
            self._quantize_loaded_weights(layer)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def _quantize_loaded_weights(self, layer: Module) -> None:
        """Quantize a non-FP8 checkpoint's weight to FP8 after loading."""
        quantized, scale = ops.scaled_fp8_quant(layer.weight, scale=None)

        # If using marlin (w8a16), kernel uses channelwise weights,
        # so extend the weight scales to be channelwise.
        if self.use_marlin:
            assert scale.numel() == 1
            per_shard = scale.expand(len(layer.logical_widths))
            scale = convert_to_channelwise(per_shard, layer.logical_widths)

        # Update the layer with the new values.
        layer.weight = Parameter(quantized.t(), requires_grad=False)
        layer.weight_scale = Parameter(scale, requires_grad=False)
        layer.input_scale = None

    def _finalize_serialized_weights(self, layer: Module) -> None:
        """Consolidate the N per-shard scales of an FP8 fused module."""
        static_activations = (
            self.quant_config.activation_scheme == "static")

        layer.weight_scale = Parameter(layer.weight_scale.data,
                                       requires_grad=False)
        if static_activations:
            layer.input_scale = Parameter(layer.input_scale.data,
                                          requires_grad=False)

        weight = layer.weight
        if self.use_marlin:
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            weight_scale = convert_to_channelwise(layer.weight_scale,
                                                  layer.logical_widths)
        else:
            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            weight_scale = layer.weight_scale

            # If rocm, use float8_e4m3fnuz.
            if current_platform.is_rocm():
                weight, weight_scale, input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=weight_scale,
                        input_scale=layer.input_scale))
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)

            # Dequant -> Quant with max scale so we can run per tensor.
            weight_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=weight_scale,
                logical_widths=layer.logical_widths,
            )

        # Update layer with new values.
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        if static_activations:
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the linear layer on ``x`` with the selected kernel.

        Marlin is tried first, then the block-wise kernel, and finally the
        default ``apply_fp8_linear`` path.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if not self.block_quant:
            return apply_fp8_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_fp8_supported=self.cutlass_fp8_supported,
                use_per_token_if_dynamic=False,
            )

        block_size = self.quant_config.weight_block_size
        assert block_size is not None
        return apply_w8a8_block_fp8_linear(
            input=x,
            weight=layer.weight,
            block_size=block_size,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )


361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
class Fp8MoEMethod(FusedMoEMethodBase):
    """Fused-MoE quantization method for FP8.

    Handles FP8 checkpoints that carry static weight scales, with either
    dynamic or static activation scales. Also handles FP16/BF16
    checkpoints: their expert weights are quantized once loading has
    finished and activations are scaled dynamically.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and input scales.

        Raises:
            ValueError: In block-wise mode, if ``intermediate_size`` is not
                aligned with the quantization block; or if the activation
                scheme is static while the checkpoint is not
                FP8-serialized.
        """
        serialized = self.quant_config.is_checkpoint_fp8_serialized
        if serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            block_shape = self.quant_config.weight_block_size
            assert block_shape is not None
            tp_size = get_tensor_model_parallel_world_size()
            block_n = block_shape[0]
            block_k = block_shape[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if tp_size > 1 and intermediate_size % block_k != 0:
                raise ValueError(f"The input_size of down's weight = "
                                 f"{intermediate_size} is not divisible by "
                                 f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        weight_shapes = (
            ("w13_weight", (2 * intermediate_size, hidden_size)),
            ("w2_weight", (hidden_size, intermediate_size)),
        )
        for weight_name, per_expert_shape in weight_shapes:
            expert_weight = torch.nn.Parameter(torch.empty(
                num_experts, *per_expert_shape, dtype=params_dtype),
                                               requires_grad=False)
            layer.register_parameter(weight_name, expert_weight)
            set_weight_attrs(expert_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            w13_scale_shape = (
                num_experts,
                2 * ((intermediate_size + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
            )
            w2_scale_shape = (
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size + block_k - 1) // block_k,
            )
        else:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_scale_shape = (num_experts, 2)
            w2_scale_shape = (num_experts, )

        w13_weight_scale = torch.nn.Parameter(torch.ones(
            w13_scale_shape, dtype=torch.float32),
                                              requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(torch.ones(
            w2_scale_shape, dtype=torch.float32),
                                             requires_grad=False)

        if self.block_quant:
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
        else:
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        if self.block_quant:
            scale_kind = FusedMoeWeightScaleSupported.BLOCK
        else:
            scale_kind = FusedMoeWeightScaleSupported.TENSOR
        extra_weight_attrs["quant_method"] = scale_kind.value

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded expert weights to the form ``apply`` expects.

        Block-wise quantized layers are left untouched. Non-FP8
        checkpoints are quantized per expert. FP8 checkpoints get their
        activation scales and the two w13 weight scales reduced to a
        single value each.
        """
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            return

        if self.quant_config.is_checkpoint_fp8_serialized:
            self._finalize_serialized_experts(layer)
        else:
            self._quantize_loaded_experts(layer)

    def _quantize_loaded_experts(self, layer: Module) -> None:
        """Quantize every expert of a non-FP8 checkpoint to FP8."""
        # If rocm, use float8_e4m3fnuz as dtype
        if current_platform.is_rocm():
            fp8_dtype = torch.float8_e4m3fnuz
        else:
            fp8_dtype = torch.float8_e4m3fn

        quant_w13 = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        quant_w2 = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # Re-initialize w13_scale because we directly quantize
        # merged w13 weights and generate a single scaling factor.
        fresh_scales = torch.ones(layer.num_experts,
                                  dtype=torch.float32,
                                  device=quant_w13.device)
        layer.w13_weight_scale = torch.nn.Parameter(fresh_scales,
                                                    requires_grad=False)

        for expert in range(layer.num_experts):
            w13_q, w13_scale = ops.scaled_fp8_quant(
                layer.w13_weight.data[expert, :, :])
            quant_w13[expert, :, :] = w13_q
            layer.w13_weight_scale[expert] = w13_scale

            w2_q, w2_scale = ops.scaled_fp8_quant(
                layer.w2_weight.data[expert, :, :])
            quant_w2[expert, :, :] = w2_q
            layer.w2_weight_scale[expert] = w2_scale

        layer.w13_weight = torch.nn.Parameter(quant_w13, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(quant_w2, requires_grad=False)

    def _finalize_serialized_experts(self, layer: Module) -> None:
        """Reduce an FP8 checkpoint's scales to what the MoE kernels need.

        The kernels require a single activation scale, and a single weight
        scale for w13 per expert.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            scales_agree = (all_close_1d(layer.w13_input_scale)
                            and all_close_1d(layer.w2_input_scale))
            if not scales_agree:
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if current_platform.is_rocm():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(layer.w13_weight,
                                             layer.w13_weight_scale,
                                             layer.w13_input_scale))
            w2_weight, w2_weight_scale, w2_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(layer.w2_weight,
                                             layer.w2_weight_scale,
                                             layer.w2_input_scale))
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            for shard_id in range(2):
                # Rows of the merged w13 weight that belong to this shard.
                rows = slice(shard_id * shard_size,
                             (shard_id + 1) * shard_size)
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][rows, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                requantized, _ = ops.scaled_fp8_quant(
                    dq_weight, max_w13_scales[expert_id])
                layer.w13_weight[expert_id][rows, :] = requantized

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the fused FP8 expert kernel.

        The routing arguments are forwarded unchanged to
        ``FusedMoE.select_experts``. The resulting weights and ids, along
        with the layer's weights and scales, go to ``fused_experts``.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        # Block-wise scales are registered under the "*_scale_inv" names.
        if self.block_quant:
            w13_scale = layer.w13_weight_scale_inv
            w2_scale = layer.w2_weight_scale_inv
        else:
            w13_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )


644
645
646
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for FP8 checkpoints.

    Supports loading kv-cache scaling factors from FP8 checkpoints. All
    behaviour comes from ``BaseKVCacheMethod``; this subclass only forwards
    the config to it.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)