"vscode:/vscode.git/clone" did not exist on "7be141b2c52366f0cc4e731c36819aed178d2258"
fp8.py 19.9 KB
Newer Older
1
from typing import Any, Dict, List, Optional
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
from vllm import _custom_ops as ops
8
from vllm.logger import init_logger
9
10
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  fused_moe)
11
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
12
from vllm.model_executor.layers.quantization.base_config import (
13
    QuantizationConfig, QuantizeMethodBase)
14
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
15
16
17
18
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
19
from vllm.model_executor.utils import set_weight_attrs
20
21
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
22

23
24
25
26
# Module-level logger for this quantization backend.
logger = init_logger(__name__)

# Activation-quantization schemes accepted by Fp8Config. With "static" the
# activation (input) scales are loaded from the checkpoint; with "dynamic"
# no activation-scale parameters are created (they are left as None).
ACTIVATION_SCHEMES = ["static", "dynamic"]

27

28
class Fp8Config(QuantizationConfig):
    """Quantization config for FP8.

    Attributes:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights together with their scales; False when weights are
            loaded in their original precision and quantized after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        # Warn before validating so the message is emitted even if the
        # activation scheme turns out to be invalid.
        if self.is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        scheme_ok = activation_scheme in ACTIVATION_SCHEMES
        if not scheme_ok:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum supported device capability (80).

        GPUs below capability 89 still work through the Marlin weight-only
        path chosen in Fp8LinearMethod.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """No extra config files are needed for FP8."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint quantization-config dict.

        The checkpoint counts as FP8-serialized when its ``quant_method``
        value contains "fp8".
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=serialized,
                   activation_scheme=scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        """Pick the FP8 quantize method matching ``layer``'s type.

        Returns None for layer types that FP8 does not handle.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        # Checked in order; the first matching layer type wins.
        dispatch = (
            (LinearBase, Fp8LinearMethod),
            (FusedMoE, Fp8MoEMethod),
            (Attention, Fp8KVCacheMethod),
        )
        for layer_type, method_type in dispatch:
            if isinstance(layer, layer_type):
                return method_type(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """FP8 defines no scaled activations."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Two loading paths are supported:

    * FP8-serialized checkpoints, which carry static weight scales and, for
      the "static" activation scheme, input (activation) scales.
    * FP16/BF16 checkpoints, whose weights are quantized to FP8 in
      ``process_weights_after_loading``. Activation scaling is dynamic
      (``input_scale`` is None) in this case.

    On devices with capability below 89 the Marlin kernel is used for
    weight-only FP8, and activations are not quantized.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm.
    2. Only float8_e4m3fn is supported, due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # GPUs without FP8 hardware support (capability < 89) fall back to
        # the Marlin kernel for fast weight-only FP8 quantization.
        cap = current_platform.get_device_capability()
        self.use_marlin = cap[0] * 10 + cap[1] < 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, its scales).

        ``output_partition_sizes`` lists the widths of the logical shards
        that are fused into this layer's output dimension.
        """
        del input_size, output_size
        cfg = self.quant_config
        out_features = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features
        layer.orig_dtype = params_dtype

        # FP8 checkpoints load e4m3 weights directly; otherwise keep the
        # original dtype and quantize in process_weights_after_loading.
        if cfg.is_checkpoint_fp8_serialized:
            weight_dtype = torch.float8_e4m3fn
        else:
            weight_dtype = params_dtype
        weight = Parameter(
            torch.empty(out_features,
                        input_size_per_partition,
                        dtype=weight_dtype),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)
        weight_attrs = {**extra_weight_attrs, "input_dim": 1, "output_dim": 0}
        set_weight_attrs(weight, weight_attrs)

        # Scales exist only in FP8 checkpoints; for other checkpoints they
        # are created after loading.
        if not cfg.is_checkpoint_fp8_serialized:
            return

        layer.register_parameter(
            "weight_scale",
            create_per_tensor_scale_param(output_partition_sizes,
                                          **extra_weight_attrs))
        if cfg.activation_scheme == "static":
            layer.register_parameter(
                "input_scale",
                create_per_tensor_scale_param(output_partition_sizes,
                                              **extra_weight_attrs))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the layer's FP8 weight, weight scale and input scale.

        The weight is stored transposed. When Marlin is used, the layer is
        repacked for Marlin and ``input_scale`` is removed.
        """
        cfg = self.quant_config
        if cfg.is_checkpoint_fp8_serialized:
            # Dequantize every separately quantized logical shard and
            # requantize the whole weight with the largest shard scale, so
            # the layer ends up with a single weight and a single scale.
            max_w_scale, fp8_weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )
            layer.weight = Parameter(fp8_weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
            # Static activation scales collapse to their maximum.
            layer.input_scale = (Parameter(layer.input_scale.max(),
                                           requires_grad=False)
                                 if cfg.activation_scheme == "static" else
                                 None)
        else:
            # High-precision checkpoint: quantize the weight now, letting
            # scaled_fp8_quant derive the scale (scale=None).
            fp8_weight, w_scale = ops.scaled_fp8_quant(layer.weight,
                                                       scale=None)
            layer.weight = Parameter(fp8_weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(w_scale, requires_grad=False)
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Marlin is weight-only: activations are not quantized.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 GEMM (or its Marlin fallback) on ``x``."""
        if not self.use_marlin:
            return apply_fp8_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_fp8_supported=self.cutlass_fp8_supported)

        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)
218
219


220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    FP8-serialized checkpoints provide static weight scales and either
    dynamic or static activation scales. FP16/BF16 checkpoints are quantized
    per expert after loading and only support dynamic activation scaling:
    asking for a static scheme with such a checkpoint raises ValueError.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and activation scales.

        Raises:
            ValueError: if the static activation scheme is requested for a
                checkpoint that is not FP8-serialized.
        """
        serialized = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if serialized else params_dtype

        def frozen(tensor: torch.Tensor) -> torch.nn.Parameter:
            # Every tensor registered here is a non-trainable parameter.
            return torch.nn.Parameter(tensor, requires_grad=False)

        # WEIGHTS. w13 stacks w1 and w3 along dim 1, hence the
        # 2 * intermediate_size rows per expert.
        w13_weight = frozen(
            torch.empty(num_experts,
                        2 * intermediate_size,
                        hidden_size,
                        dtype=weight_dtype))
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = frozen(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size,
                        dtype=weight_dtype))
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT SCALES. w13 gets two scales per expert (w1 and w3); they
        # are merged into one per expert in process_weights_after_loading.
        w13_scale = frozen(torch.ones(num_experts, 2, dtype=torch.float32))
        layer.register_parameter("w13_scale", w13_scale)
        w2_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
        layer.register_parameter("w2_scale", w2_scale)

        # Only FP8 checkpoints contain weight scales, so only they get the
        # weight-loader attributes. FP16 checkpoints compute scales during
        # process_weights_after_loading.
        if serialized:
            for scale in (w13_scale, w2_scale):
                set_weight_attrs(scale, extra_weight_attrs)

        # INPUT SCALES
        if self.quant_config.activation_scheme != "static":
            layer.a13_scale = None
            layer.a2_scale = None
            return
        if not serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")
        for scale_name in ("a13_scale", "a2_scale"):
            act_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
            layer.register_parameter(scale_name, act_scale)
            set_weight_attrs(act_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring loaded weights and scales into the fused-MoE kernel format."""
        if self.quant_config.is_checkpoint_fp8_serialized:
            self._merge_fp8_checkpoint_scales(layer)
        else:
            self._quantize_high_precision_weights(layer)

    def _quantize_high_precision_weights(self, layer: Module) -> None:
        """Quantize FP16/BF16 expert weights to FP8, one scale per expert."""
        fp8 = torch.float8_e4m3fn
        new_w13 = torch.empty_like(layer.w13_weight.data, dtype=fp8)
        new_w2 = torch.empty_like(layer.w2_weight.data, dtype=fp8)

        # The merged w13 weight is quantized as a whole here, so replace the
        # (num_experts, 2) scale with a single scale per expert.
        layer.w13_scale = torch.nn.Parameter(torch.ones(
            layer.num_experts, dtype=torch.float32, device=new_w13.device),
                                             requires_grad=False)
        for e in range(layer.num_experts):
            q13, s13 = ops.scaled_fp8_quant(layer.w13_weight.data[e, :, :])
            new_w13[e, :, :] = q13
            layer.w13_scale[e] = s13
            q2, s2 = ops.scaled_fp8_quant(layer.w2_weight.data[e, :, :])
            new_w2[e, :, :] = q2
            layer.w2_scale[e] = s2

        layer.w13_weight = torch.nn.Parameter(new_w13, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(new_w2, requires_grad=False)

    def _merge_fp8_checkpoint_scales(self, layer: Module) -> None:
        """Collapse FP8 checkpoint scales to what the MoE kernel expects.

        The kernel needs a single activation scale and a single w13 weight
        scale per expert. The maximum is used wherever scales differ.

        Raises:
            ValueError: if the static scheme is configured but activation
                scales are missing.
        """
        if self.quant_config.activation_scheme == "static":
            if layer.a13_scale is None or layer.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            scales_uniform = (all_close_1d(layer.a13_scale)
                              and all_close_1d(layer.a2_scale))
            if not scales_uniform:
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                 requires_grad=False)
            layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                requires_grad=False)

        # Requantize both w13 shards of each expert (w1 rows, then w3 rows)
        # with that expert's largest shard scale.
        assert layer.w13_scale is not None
        shard = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            for shard_id in range(2):
                rows = slice(shard_id * shard, (shard_id + 1) * shard)
                dequant = per_tensor_dequantize(
                    layer.w13_weight[expert_id][rows, :],
                    layer.w13_scale[expert_id][shard_id])
                requant, _ = ops.scaled_fp8_quant(dequant,
                                                  max_w13_scales[expert_id])
                layer.w13_weight[expert_id][rows, :] = requant

        layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                             requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:
        """Run the FP8 fused-MoE kernel (in place) on ``x``."""
        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            top_k,
            renormalize=renormalize,
            inplace=True,
            use_fp8=True,
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale,
            a1_scale=layer.a13_scale,
            a2_scale=layer.a2_scale,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
400
401


402
403
404
405
406
407
408
409
class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints.

    ``create_weights`` registers placeholder ``k_scale``/``v_scale``
    parameters. ``process_weights_after_loading`` converts whatever was
    loaded into Python floats on ``layer._k_scale``/``layer._v_scale`` and
    then deletes the placeholder parameters.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka k_scale and v_scale) for an attention layer.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        # Initialize the KV cache scales to -1.0, which is an invalid value
        # marking "not loaded". If the k/v_scale appears in the checkpoint,
        # it will be overwritten when loading weights.
        layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        """Resolve the loaded k/v scales into ``layer._k_scale/_v_scale``.

        Only done when ``layer.kv_cache_dtype`` is not "auto":

        * both scales loaded (positive): use them as-is;
        * neither loaded (both still -1.0): default both to 1.0;
        * otherwise (a single kv_scale, which the weight loader maps to
          ``k_scale``): use the larger of the two for both.

        A warning is printed whenever both scales end up as 1.0 and the
        cache dtype is not e5m2. The ``k_scale``/``v_scale`` parameters are
        deleted in every case.

        Raises:
            ValueError: if a resolved scale is not a single Python float.
            AssertionError: in the mixed case, if ``k_scale`` is not
                positive.
        """
        # For kv_cache_dtype "auto" nothing is written to _k_scale/_v_scale
        # here. The intent is that the k/v_scale stays 1.0 regardless of the
        # checkpoint; presumably the Attention layer's own defaults provide
        # that (not visible in this file).
        if layer.kv_cache_dtype != "auto":
            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                # We prefer to use separate k_scale and v_scale if present
                k_scale = layer.k_scale.to("cpu").tolist()
                v_scale = layer.v_scale.to("cpu").tolist()
            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                # If no scales were loaded (both scales are invalid negative
                # values), use the default value of 1.0. These must be plain
                # floats: a Parameter would fail the per-tensor check below.
                k_scale = 1.0
                v_scale = 1.0
            else:
                # If we find a single kv_scale in the checkpoint, we remap
                # kv_scale to k_scale during weight loading, and duplicate
                # k_scale to v_scale here
                assert layer.k_scale > 0.0
                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                k_scale = scale_to_duplicate.to("cpu").tolist()
                v_scale = scale_to_duplicate.to("cpu").tolist()

            if not isinstance(k_scale, float) or not isinstance(
                    v_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")

            # These are used in the final Attention.forward()
            layer._k_scale = k_scale
            layer._v_scale = v_scale
            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
                    and "e5m2" not in layer.kv_cache_dtype):
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure k/v_scale "
                    "scaling factors are available in the fp8 checkpoint.")

        del layer.k_scale
        del layer.v_scale