# fp8.py
1
from typing import Any, Dict, List, Optional, Union
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
from vllm import _custom_ops as ops
8
from vllm.logger import init_logger
9
10
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  fused_moe)
11
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
12
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
14
from vllm.model_executor.utils import set_weight_attrs
15
16
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
17

18
19
20
21
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

22

23
def cutlass_fp8_supported() -> bool:
    """Report whether the CUTLASS scaled-mm kernel supports FP8 here.

    The first two entries of the capability reported by the current platform
    are folded into one integer (``first * 10 + second``) and passed to
    ``ops.cutlass_scaled_mm_supports_fp8``.

    Returns:
        Whatever ``ops.cutlass_scaled_mm_supports_fp8`` reports for that
        capability value.
    """
    device_capability = current_platform.get_device_capability()
    capability_code = device_capability[0] * 10 + device_capability[1]

    return ops.cutlass_scaled_mm_supports_fp8(capability_code)
28
29


30
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights (as opposed to FP16/BF16 weights that are quantized
            after loading).
        activation_scheme: One of ``ACTIVATION_SCHEMES``.

    Raises:
        ValueError: If ``activation_scheme`` is not in ``ACTIVATION_SCHEMES``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum supported device capability.

        NOTE(review): presumably uses the same ``major * 10 + minor``
        encoding as ``cutlass_fp8_supported`` (i.e. 8.9) -- confirm against
        the caller.
        """
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names to look for (none for FP8)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dict.

        Reads the ``quant_method`` and ``activation_scheme`` entries through
        ``cls.get_from_keys``. The checkpoint is treated as FP8-serialized
        when ``"fp8"`` is contained in ``quant_method``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        """Pick the FP8 method matching the type of ``layer``.

        Returns:
            ``Fp8LinearMethod`` for ``LinearBase`` layers, ``Fp8MoEMethod``
            for ``FusedMoE`` layers, ``Fp8KVCacheMethod`` for ``Attention``
            layers, and ``None`` for anything else.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for FP8)."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        # Queried once at construction; apply() uses it to pick the GEMM path.
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def _create_scale_param(
        self,
        scale_name: str,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        **extra_weight_attrs,
    ) -> None:
        """Register a float32 scale parameter on ``layer``.

        The parameter has one entry per element of
        ``output_partition_sizes``. Every entry starts at the float8_e4m3fn
        minimum, which process_weights_after_loading() later treats as
        "this scale was not loaded from the checkpoint".

        Args:
            scale_name: Attribute name under which the parameter is
                registered.
            layer: Module that receives the parameter.
            output_partition_sizes: Sizes of the logical output shards.
            **extra_weight_attrs: Attributes attached to the parameter, in
                addition to ``needs_scalar_to_array``.
        """
        scale = Parameter(torch.empty(len(output_partition_sizes),
                                      dtype=torch.float32),
                          requires_grad=False)
        scale[:] = torch.finfo(torch.float8_e4m3fn).min
        layer.register_parameter(scale_name, scale)
        set_weight_attrs(scale, {
            **extra_weight_attrs,
            "needs_scalar_to_array": True,
        })

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the weight (and, for FP8 checkpoints, scales) on ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input width of this partition.
            output_partition_sizes: Widths of the logical output shards that
                are stacked into the single ``weight`` parameter.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype when the checkpoint is not FP8
                serialized.
            **extra_weight_attrs: Attributes attached to every parameter
                created here.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.process_after_load = True
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            self._create_scale_param(
                scale_name="weight_scale",
                layer=layer,
                output_partition_sizes=output_partition_sizes,
                **extra_weight_attrs)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                self._create_scale_param(
                    scale_name="input_scale",
                    layer=layer,
                    output_partition_sizes=output_partition_sizes,
                    **extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize weights and scales once the checkpoint has been loaded.

        Does nothing unless ``layer.process_after_load`` is set and truthy.
        For FP16/BF16 checkpoints the weight is quantized to FP8 here. For
        FP8 checkpoints the per-shard weight scales are collapsed into a
        single scale, requantizing the shards when needed. In both cases the
        weight is stored transposed.

        Raises:
            ValueError: If the activation scheme is neither "dynamic" nor
                "static" (FP8 checkpoints only).
        """
        if not getattr(layer, "process_after_load", False):
            return

        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.logical_widths = None
            layer.input_scale = None
            return

        # If checkpoint is fp8, requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.

        # WEIGHT_SCALE / WEIGHT
        #   Loop over logical weights, requantizing with single scale.
        max_w_scale = layer.weight_scale.max()

        # QKV / MLP is fused in the on disk checkpoint if any of the
        # weight scales are still set to the default since we initialize
        # N weight scales for N shards but we only load 1 weight scale
        # from disk in this case. As a result, we skip dequant -> requant
        # since we already have quantized QKV together.
        # Sample Model with fused checkpoint:
        #   * nm-testing/Phi-3-mini-128k-instruct-FP8
        unfused_module_in_checkpoint = (
            layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)

        if unfused_module_in_checkpoint:
            start = 0
            for idx, logical_width in enumerate(layer.logical_widths):
                end = start + logical_width
                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
                                                  layer.weight_scale[idx])

                # weight_scale is not modified inside this loop, so the
                # maximum computed above is reused for every shard.
                layer.weight[start:end, :] = per_tensor_quantize(
                    weight_dq, max_w_scale)
                start = end
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # WEIGHT
        #   Transpose weight for passing to torch._scaled_mm
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

        # INPUT ACTIVATION SCALE
        #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
        #   Static:  set to max of the input_scales (since they are equal).
        if self.quant_config.activation_scheme == "dynamic":
            layer.input_scale = None
        elif self.quant_config.activation_scheme == "static":
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)
        else:
            raise ValueError(
                f"Unknown scheme {self.quant_config.activation_scheme}")

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Quantize ``x`` to FP8 and multiply it with the layer's FP8 weight.

        Uses ``ops.cutlass_scaled_mm`` when there is no bias and CUTLASS FP8
        is supported; otherwise falls back to ``torch._scaled_mm``.

        Args:
            layer: Layer prepared by process_weights_after_loading().
            x: Input activations.
            bias: Optional bias; only the ``torch._scaled_mm`` path takes it.

        Returns:
            The output narrowed along dim 0 to ``x.shape[0]`` rows, with
            ``x``'s dtype requested as the output dtype.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.

        if bias is None and self.cutlass_fp8_supported:
            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)

            # Fused GEMM_DQ
            output = ops.cutlass_scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
            )

        else:
            qinput, x_scale = ops.scaled_fp8_quant(x,
                                                   layer.input_scale,
                                                   batch_dim_padding=17)

            # Fused GEMM_DQ -- note we padded the input above because
            # torch._scaled_mm is more performant for matrices with
            # batch dimension > 16. Note that this could change
            # in the future.
            output, _ = torch._scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                bias=bias,
            )

        # Drop any rows that were added by the batch padding above.
        return torch.narrow(output, 0, 0, x.shape[0])
276
277


278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
class Fp8MoEMethod(FusedMoEMethodBase):
    """FP8 quantization method for fused MoE layers.

    Handles FP8 checkpoints that carry static weight scales together with
    either dynamic or static activation scales.

    FP16/BF16 checkpoints are handled as well (dynamic activation scaling
    only); their weight scales are computed once loading has finished.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate expert weights and their scales on ``layer``.

        Raises:
            ValueError: If the activation scheme is "static" but the
                checkpoint is not FP8 serialized.
        """
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized

        layer.process_after_load = True

        # FP8 checkpoints are loaded directly into FP8 storage.
        weight_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype

        # WEIGHTS
        w13_data = torch.empty(num_experts,
                               2 * intermediate_size,
                               hidden_size,
                               dtype=weight_dtype)
        w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_data = torch.empty(num_experts,
                              hidden_size,
                              intermediate_size,
                              dtype=weight_dtype)
        w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # w1 and w3 each get their own scale slot for now; the two are
        # merged into one scale per expert after the weights are loaded.
        w13_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

        # Only an fp8 checkpoint has scales on disk, so only then are the
        # weight attributes attached to them. For fp16 checkpoints the
        # scales are produced in process_weights_after_loading().
        if fp8_serialized:
            for weight_scale in (w13_scale, w2_scale):
                set_weight_attrs(weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.a13_scale = None
            layer.a2_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_name in ("a13_scale", "a2_scale"):
            act_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32),
                requires_grad=False)
            layer.register_parameter(scale_name, act_scale)
            set_weight_attrs(act_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring weights and scales into the form the FP8 MoE kernel expects.

        Does nothing unless ``layer.process_after_load`` is set and truthy.

        Raises:
            ValueError: If the scheme is "static" but an activation scale on
                the layer is None.
        """
        if not getattr(layer, "process_after_load", False):
            return

        # fp16 checkpoint: quantize every expert's weights to fp8.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = torch.float8_e4m3fn
            quant_w13 = torch.empty_like(layer.w13_weight.data,
                                         dtype=fp8_dtype)
            quant_w2 = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # The merged w13 weight is quantized as a whole, so w13_scale
            # is rebuilt with a single entry per expert.
            fresh_w13_scale = torch.ones(layer.num_experts,
                                         dtype=torch.float32,
                                         device=quant_w13.device)
            layer.w13_scale = torch.nn.Parameter(fresh_w13_scale,
                                                 requires_grad=False)

            for expert in range(layer.num_experts):
                w13_q, w13_s = ops.scaled_fp8_quant(
                    layer.w13_weight.data[expert, :, :])
                quant_w13[expert, :, :] = w13_q
                layer.w13_scale[expert] = w13_s

                w2_q, w2_s = ops.scaled_fp8_quant(
                    layer.w2_weight.data[expert, :, :])
                quant_w2[expert, :, :] = w2_q
                layer.w2_scale[expert] = w2_s

            layer.w13_weight = torch.nn.Parameter(quant_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(quant_w2,
                                                 requires_grad=False)
            return

        # fp8 checkpoint: the MoE kernels need one activation scale and one
        # w13 weight scale per expert.

        # Collapse the per-expert activation scales to their maximum, in
        # case they differ.
        if self.quant_config.activation_scheme == "static":
            if layer.a13_scale is None or layer.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            scales_agree = (all_close_1d(layer.a13_scale)
                            and all_close_1d(layer.a2_scale))
            if not scales_agree:
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                 requires_grad=False)
            layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                requires_grad=False)

        # Requantize both w13 shards of every expert with the larger of
        # the expert's two scales.
        assert layer.w13_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            for shard_id in range(2):
                rows = slice(shard_id * shard_size,
                             (shard_id + 1) * shard_size)
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][rows, :],
                    layer.w13_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][rows, :] = per_tensor_quantize(
                    dq_weight, max_w13_scales[expert_id])

        layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                             requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True) -> torch.Tensor:
        """Run the fused MoE kernel in FP8 mode with the layer's scales."""
        moe_output = fused_moe(x,
                               layer.w13_weight,
                               layer.w2_weight,
                               router_logits,
                               top_k,
                               renormalize=renormalize,
                               inplace=True,
                               use_fp8=True,
                               w1_scale=layer.w13_scale,
                               w2_scale=layer.w2_scale,
                               a1_scale=layer.a13_scale,
                               a2_scale=layer.a2_scale)
        return moe_output


459
460
461
462
463
464
465
466
class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka kv_scale) for an attention layer.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        # Initialize the KV cache scale to 1.0 as the default value.
        # If the kv_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        """Always raise; this method only carries the kv-cache scale.

        Raises:
            RuntimeError: Unconditionally.
        """
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        """Move the loaded scale into ``layer._kv_scale`` and drop the param.

        ``layer._kv_scale`` is only written when ``layer.kv_cache_dtype`` is
        not "auto"; ``layer.kv_scale`` is deleted in every case.

        Raises:
            ValueError: If the loaded scale does not convert to a single
                Python float (i.e. it is not a per-tensor scale).
        """
        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
        # regardless whether the kv-scale is available in the checkpoint.
        if layer.kv_cache_dtype != "auto":
            kv_scale = layer.kv_scale.to("cpu").tolist()
            if not isinstance(kv_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")
            layer._kv_scale = kv_scale
            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
                    "cause accuracy issues. Please make sure kv-cache scaling "
                    "factor is available in the fp8 checkpoint.")
        del layer.kv_scale


497
def per_tensor_quantize(tensor: torch.Tensor,
498
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
499
500
501
502
503
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


504
505
506
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
507
508
509
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight
510
511
512
513
514


def all_close_1d(x: torch.Tensor) -> bool:
    """Return True when every entry of the 1-D tensor ``x`` is close to x[0].

    Closeness is decided by ``torch.allclose`` with its default tolerances.
    An empty tensor yields True.
    """
    assert x.dim() == 1
    for position in range(x.shape[0]):
        if not torch.allclose(x[0], x[position]):
            return False
    return True