# fp8.py: FP8 quantization config and quantize methods for linear,
# fused-MoE and attention (KV-cache) layers.

from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once


# Module-level logger.
logger = init_logger(__name__)

# Activation scaling schemes accepted by Fp8Config; any other value makes
# Fp8Config.__init__ raise ValueError.
ACTIVATION_SCHEMES = ["static", "dynamic"]


class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            float8_e4m3fn weights together with their scales. When False,
            weights are loaded in their original dtype and quantized to FP8
            after loading.
        activation_scheme: Activation scaling scheme; must be one of
            ``ACTIVATION_SCHEMES`` ("static" or "dynamic").
        ignored_layers: Names checked with ``is_layer_skipped``; matching
            linear layers are left unquantized. Defaults to an empty list.

    Raises:
        ValueError: If ``activation_scheme`` is not in ``ACTIVATION_SCHEMES``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        Devices below capability 89 are served through the Marlin
        weight-only path in ``Fp8LinearMethod``.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config files to look for (none for FP8)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build a config from a checkpoint's quantization config dict.

        Reads the required keys "quant_method" and "activation_scheme" and
        the optional "ignored_layers". The checkpoint is treated as
        FP8-serialized when "fp8" appears in the quant_method value.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get ``Fp8LinearMethod`` unless ``prefix`` matches
        ``ignored_layers``, in which case they stay unquantized. Fused MoE
        layers get ``Fp8MoEMethod``. Attention layers get
        ``Fp8KVCacheMethod``. Any other layer gets None.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return names of activations needing scaling (none for FP8)."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with
    dynamic activation scaling: the weights are quantized to FP8 and their
    scaling factor is computed after the model weights are loaded.

    On GPUs without FP8 hardware support (device capability < 89), or when
    VLLM_TEST_FORCE_FP8_MARLIN is set, the Marlin kernel is used for
    weight-only FP8 (activations are not quantized). Marlin is never used
    on ROCm. On ROCm, FP8 checkpoint weights are normalized to
    float8_e4m3fnuz.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, its scales).

        Args:
            layer: Layer receiving the parameters and partition metadata
                (logical_widths, input/output partition sizes, orig_dtype).
            input_size_per_partition: Input features in this partition.
            output_partition_sizes: Output sizes of the logical shards fused
                into this layer; their sum is the partition's output size.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Dtype of the checkpoint weights. It is used as the
                weight dtype when the checkpoint is not FP8-serialized.
            **extra_weight_attrs: Only "weight_loader" is read.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE: one scale per logical shard.
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)
            # Placeholder value. It is presumably overwritten by the weight
            # loader for every shard (assumption, not verified here).
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout ``apply`` expects.

        The weight ends up stored transposed. For the w8a8 path there is a
        single per-tensor weight scale. For Marlin, the weight scale is
        expanded to channelwise and input_scale is removed.
        """
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if current_platform.is_rocm():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # Collapse the per-shard input scales to a single max scale.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (plus optional ``bias``).

        Uses the Marlin weight-only kernel when ``use_marlin`` is set, and
        ``apply_fp8_linear`` (w8a8) otherwise.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with
    dynamic activation scaling: the expert weights are quantized to FP8 and
    their scaling factors are computed after the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and input scales on layer.

        Weights are (num_experts, 2 * intermediate_size, hidden_size) for the
        fused w1/w3 ("w13") and (num_experts, hidden_size, intermediate_size)
        for w2.

        Raises:
            ValueError: If the static activation scheme is requested for a
                checkpoint that is not FP8-serialized. The check runs before
                any tensors are allocated or registered.
        """
        # Validate first so an invalid config does not allocate expert
        # weights or leave a half-initialized layer behind.
        if (self.quant_config.activation_scheme == "static"
                and not self.quant_config.is_checkpoint_fp8_serialized):
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring loaded expert weights/scales into the fused-MoE kernel form.

        FP16/BF16 checkpoints are quantized per expert. FP8 checkpoints get:
        - a single (max) activation scale;
        - on ROCm, a conversion to e4m3fnuz;
        - a single w13 weight scale per expert, by requantizing both shards
          with their max scale.

        Relies on ``layer.num_experts`` and
        ``layer.intermediate_size_per_partition`` being set by the layer
        (assumption based on usage; not set in this class).
        """
        # If checkpoint is fp16, quantize the weights into new fp8 tensors.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = (torch.float8_e4m3fnuz
                         if current_platform.is_rocm() else torch.float8_e4m3fn)
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # Checkpoint is fp8: the MoE kernels require a single activation
        # scale and a single w13 weight scale per expert.

        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if current_platform.is_rocm():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            # Shard 0 is w1, shard 1 is w3, stacked along dim 0 of w13.
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Route ``x`` to experts and run the fused FP8 (w8a8) MoE kernel.

        Note: ``fused_experts`` is called with ``inplace=True``.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8_w8a8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    This subclass only binds ``BaseKVCacheMethod`` to an ``Fp8Config``;
    it adds no behavior of its own.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)