# fp8.py: FP8 quantization config and linear / fused-MoE / KV-cache methods.
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

# Activation scaling schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Describes how a model is quantized to FP8: whether the checkpoint is
    already FP8-serialized, how activations are scaled, which linear layers
    stay unquantized, and (optionally) the block shape used for block-wise
    weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Args:
            is_checkpoint_fp8_serialized: True when the checkpoint weights
                are stored in FP8; otherwise higher-precision weights are
                quantized after loading.
            activation_scheme: Activation scaling scheme; must be one of
                ``ACTIVATION_SCHEMES`` ("static" or "dynamic").
            ignored_layers: Names checked by ``is_layer_skipped`` against a
                linear layer's prefix; matching linear layers are left
                unquantized. Defaults to an empty list.
            weight_block_size: Optional ``[block_n, block_k]`` block shape
                for block-wise weight quantization (block_n tiles the output
                dimension, block_k the input dimension).

        Raises:
            ValueError: If ``activation_scheme`` is unsupported, or if
                ``weight_block_size`` is used with a non-FP8 checkpoint, does
                not have exactly two positive entries, or is combined with a
                non-dynamic activation scheme.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            # The block sizes are later used as divisors (`%` checks and
            # ceil-division of weight shapes into scale tiles), so reject
            # non-positive values up front instead of failing or mis-sizing
            # the scale tensors during weight creation.
            if any(size <= 0 for size in weight_block_size):
                raise ValueError(
                    "The quantization block size of weight must be positive, "
                    f"but got {weight_block_size}.")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (80).

        Devices below capability 89 fall back to the Marlin weight-only
        kernel in ``Fp8LinearMethod``.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config file names to read; FP8 needs none."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config.

        The checkpoint counts as FP8-serialized when its ``quant_method``
        string contains "fp8". ``ignored_layers`` and ``weight_block_size``
        are optional keys.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers whose ``prefix``
            matches ``ignored_layers``, ``Fp8LinearMethod`` for other linear
            layers, ``Fp8MoEMethod`` for ``FusedMoE``, ``Fp8KVCacheMethod``
            for ``Attention``, and ``None`` for anything else.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        # NOTE: ignored_layers is only consulted for linear layers above;
        # MoE and attention layers are always given an FP8 method here.
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    When ``weight_block_size`` is configured, FP8-serialized checkpoints with
    block-wise weight scales (stored as ``weight_scale_inv``) and dynamic
    activation scaling are supported as well.

    Limitations:
    1. Outside the block-wise path, checkpoint weight scales are per-tensor
       (one per logical shard). For the w8a8 path the shards are requantized
       to a single per-tensor scale; the Marlin path (devices without FP8
       hardware support) expands them to channelwise scales and leaves
       activations unquantized.
    2. Checkpoints are expected in float8_e4m3fn (the dtype torch._scaled_mm
       supports, see https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856);
       on ROCm they are converted to float8_e4m3fnuz after loading.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.block_quant = self.quant_config.weight_block_size is not None
        if self.block_quant:
            # Marlin doesn't support block-wise fp8
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and (for FP8 checkpoints) scale parameters.

        Args:
            layer: Layer that receives the parameters.
            input_size_per_partition: Input features on this TP rank.
            output_partition_sizes: Output sizes of the logical shards fused
                into this layer (one entry per shard).
            input_size: Full (unpartitioned) input size.
            output_size: Full (unpartitioned) output size.
            params_dtype: Dtype used for the weight when the checkpoint is not
                FP8-serialized; also recorded as ``layer.orig_dtype``.
            **extra_weight_attrs: Must provide ``weight_loader``.

        Raises:
            ValueError: With block quantization, if a partitioned dimension is
                not divisible by the corresponding block size.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                # One per-tensor scale per logical shard.
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, ceil-divided.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights/scales into the form ``apply`` expects.

        * Block quant: weights are used as loaded (ROCm only converts them to
          float8_e4m3fnuz).
        * Non-FP8 checkpoint: quantize the weight with a dynamic per-tensor
          scale.
        * FP8 checkpoint: collapse per-shard scales (requantize for w8a8, or
          expand to channelwise for Marlin) and reduce a static input scale
          to its maximum.

        Non-block weights end up stored transposed, i.e. with shape
        (input, output) relative to how they were created.
        """
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            if current_platform.is_rocm():
                weight, weight_scale, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv,
                        input_scale=layer.input_scale)
                layer.weight = Parameter(weight, requires_grad=False)
                layer.weight_scale_inv = Parameter(weight_scale,
                                                   requires_grad=False)
            return

        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if current_platform.is_rocm():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (plus optional ``bias``).

        Dispatches to the Marlin weight-only kernel, the block-wise w8a8
        kernel, or the per-tensor w8a8 path, depending on how the method was
        configured.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            # Note: lazy import to avoid triton import error. Kept inside
            # this branch so non-block layers never trigger the import.
            from vllm.model_executor.layers.quantization.utils.fp8_utils import (  # noqa: E501
                apply_w8a8_block_fp8_linear)

            assert self.quant_config.weight_block_size is not None
            return apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            # Default to using per_token quantization if cutlass is supported
            use_per_token_if_dynamic=self.cutlass_fp8_supported)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    When ``weight_block_size`` is configured, block-wise weight scales
    (``w13_weight_scale_inv`` / ``w2_weight_scale_inv``) with dynamic
    activation scaling are supported.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the fused w13 / w2 expert weights and their scales.

        w13 has shape (num_experts, 2 * intermediate_size, hidden_size) and
        w2 has shape (num_experts, hidden_size, intermediate_size).

        Raises:
            ValueError: With block quantization, if ``intermediate_size`` is
                not divisible by block_n, or (with tensor parallelism) by
                block_k; or if a static activation scheme is used with a
                non-FP8 checkpoint.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1 and intermediate_size % block_k != 0):
                # Required by row parallel
                raise ValueError(f"The input_size of down's weight = "
                                 f"{intermediate_size} is not divisible by "
                                 f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, ceil-divided; w13 holds
            # the tiles of both the w1 and w3 halves.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK
                      if self.block_quant else
                      FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs.update({"quant_method": scale_kind.value})

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    @staticmethod
    def _normalize_weights_to_e4m3fnuz(layer: Module,
                                       scale_name: str) -> None:
        """Convert w13/w2 weights and scales to float8_e4m3fnuz (ROCm).

        Shared by the block-wise and per-tensor paths, which differ only in
        the weight-scale attribute suffix: ``scale_name`` is
        ``"weight_scale_inv"`` or ``"weight_scale"``. Input scales are only
        replaced when ``normalize_e4m3fn_to_e4m3fnuz`` returns one.
        """
        for prefix in ("w13", "w2"):
            weight, weight_scale, input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    getattr(layer, f"{prefix}_weight"),
                    getattr(layer, f"{prefix}_{scale_name}"),
                    getattr(layer, f"{prefix}_input_scale"))
            # Reset the parameters
            setattr(layer, f"{prefix}_weight",
                    torch.nn.Parameter(weight, requires_grad=False))
            setattr(layer, f"{prefix}_{scale_name}",
                    torch.nn.Parameter(weight_scale, requires_grad=False))
            if input_scale is not None:
                setattr(layer, f"{prefix}_input_scale",
                        torch.nn.Parameter(input_scale, requires_grad=False))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded expert weights/scales into the form ``apply`` needs.

        * Block quant: weights are used as loaded (ROCm only converts them to
          float8_e4m3fnuz).
        * Non-FP8 checkpoint: quantize each expert's w13 and w2 with one
          dynamic scale each.
        * FP8 checkpoint: reduce static input scales to their maximum, and
          requantize each expert's w1/w3 halves to a single w13 scale (the
          max of the two).

        Raises:
            ValueError: If the static activation scheme is configured but an
                input scale is missing.
        """
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            if current_platform.is_rocm():
                self._normalize_weights_to_e4m3fnuz(layer, "weight_scale_inv")
            return

        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = (torch.float8_e4m3fnuz
                         if current_platform.is_rocm() else torch.float8_e4m3fn)
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.

        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if current_platform.is_rocm():
            self._normalize_weights_to_e4m3fnuz(layer, "weight_scale")

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start +
                                                shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route ``x`` to its top-k experts and run the fused FP8 MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        block-wise or per-tensor weight scales are passed to
        ``fused_experts`` depending on the configuration.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    Returned by ``Fp8Config.get_quant_method`` for attention layers; the
    constructor simply forwards the config to ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)