fp8.py 12.5 KB
Newer Older
1
from typing import Any, Dict, List, Optional, Union
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
from vllm import _custom_ops as ops
8
from vllm.logger import init_logger
9
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
10
from vllm.model_executor.layers.quantization.base_config import (
11
    QuantizationConfig, QuantizeMethodBase)
12
from vllm.model_executor.utils import set_weight_attrs
13
from vllm.utils import get_device_capability_stateless, print_warning_once
14

15
16
17
18
# Activation quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger, used e.g. for the experimental-format warning that
# Fp8Config emits when an fp8 checkpoint is detected.
logger = init_logger(__name__)

19

20
def cutlass_fp8_supported() -> bool:
    """Report whether the CUTLASS scaled-mm kernel supports FP8 here.

    The device capability is folded into a single integer as
    ``major * 10 + minor`` and handed to
    ``ops.cutlass_scaled_mm_supports_fp8``.
    """
    device_cap = get_device_capability_stateless()
    major, minor = device_cap[0], device_cap[1]
    return ops.cutlass_scaled_mm_supports_fp8(major * 10 + minor)
25
26


27
class Fp8Config(QuantizationConfig):
    """Quantization config for FP8.

    Attributes:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            fp8 weights (and their scales).
        activation_scheme: One of ACTIVATION_SCHEMES ("static" or "dynamic").
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        """Store the settings, warning on fp8 checkpoints.

        Raises:
            ValueError: If ``activation_scheme`` is not in ACTIVATION_SCHEMES.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected fp8 checkpoint. Please note that the "
                "format is experimental and subject to change.")

        scheme_is_known = activation_scheme in ACTIVATION_SCHEMES
        if not scheme_is_known:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): presumably encoded as ``major * 10 + minor`` (i.e. 8.9),
        matching cutlass_fp8_supported in this file -- confirm with callers.
        """
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a quantization config dict.

        Reads the "quant_method" and "activation_scheme" keys; the checkpoint
        is treated as fp8-serialized when "fp8" occurs in the quant method.
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=scheme,
        )

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get Fp8LinearMethod, attention layers get
        Fp8KVCacheMethod, and every other layer gets None.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            method = Fp8LinearMethod(self)
        elif isinstance(layer, Attention):
            method = Fp8KVCacheMethod(self)
        else:
            method = None
        return method

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        # Probe once at construction; apply() uses this to pick the GEMM path.
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def _create_scale_param(
        self,
        scale_name: str,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        **extra_weight_attrs,
    ) -> None:
        """Register a float32 scale parameter with one entry per shard.

        Args:
            scale_name: Attribute name to register the parameter under.
            layer: Layer that receives the parameter.
            output_partition_sizes: Output widths of the logical shards; only
                the count is used here.
            **extra_weight_attrs: Extra attributes attached to the parameter.
        """
        scale = Parameter(torch.empty(len(output_partition_sizes),
                                      dtype=torch.float32),
                          requires_grad=False)
        # Every entry starts at the fp8 minimum; process_weights_after_loading
        # treats a value still at this default as "not loaded from disk".
        scale[:] = torch.finfo(torch.float8_e4m3fn).min
        layer.register_parameter(scale_name, scale)
        set_weight_attrs(scale, {
            **extra_weight_attrs,
            "needs_scalar_to_array": True,
        })

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the weight (and, for fp8 checkpoints, scale) parameters.

        Args:
            layer: Layer that receives the parameters.
            input_size_per_partition: Number of weight columns.
            output_partition_sizes: Output widths of the logical shards; their
                sum is the number of weight rows.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype when the checkpoint is not fp8.
            **extra_weight_attrs: Extra attributes attached to each parameter.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Flag and shard layout consumed by process_weights_after_loading.
        layer.process_after_load = True
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            self._create_scale_param(
                scale_name="weight_scale",
                layer=layer,
                output_partition_sizes=output_partition_sizes,
                **extra_weight_attrs)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                self._create_scale_param(
                    scale_name="input_scale",
                    layer=layer,
                    output_partition_sizes=output_partition_sizes,
                    **extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize weights and scales once the checkpoint has been loaded.

        For fp16/bf16 checkpoints the weight is quantized to fp8 here. For fp8
        checkpoints the per-shard scales are collapsed into a single scale,
        requantizing the shards when they were stored separately. In both
        cases the weight is stored transposed afterwards.

        Raises:
            ValueError: If the configured activation scheme is unknown.
        """
        # Only layers flagged by create_weights need post-processing.
        if not getattr(layer, "process_after_load", False):
            return

        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.logical_widths = None
            layer.input_scale = None
            return

        # If checkpoint is fp8, requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.

        # WEIGHT_SCALE / WEIGHT
        #   Loop over logical weights, requantizing with single scale.
        max_w_scale = layer.weight_scale.max()

        # QKV / MLP is fused in the on disk checkpoint if any of the
        # weight scales are still set to the default since we initialize
        # N weight scales for N shards but we only load 1 weight scale
        # from disk in this case. As a result, we skip dequant -> requant
        # since we already have quantized QKV together.
        # Sample Model with fused checkpoint:
        #   * nm-testing/Phi-3-mini-128k-instruct-FP8
        unfused_module_in_checkpoint = (
            layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)

        if unfused_module_in_checkpoint:
            start = 0
            for idx, logical_width in enumerate(layer.logical_widths):
                end = start + logical_width
                weight_dq = per_tensor_dequantize(
                    layer.weight[start:end, :], layer.weight_scale[idx])

                # weight_scale is not modified inside this loop, so the max
                # computed above is reused instead of recomputed per shard.
                layer.weight[start:end, :] = per_tensor_quantize(
                    weight_dq, max_w_scale)
                start = end
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # WEIGHT
        #   Transpose weight for passing to torch._scaled_mm
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

        # INPUT ACTIVATION SCALE
        #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
        #   Static:  set to max of the input_scales (since they are equal).
        if self.quant_config.activation_scheme == "dynamic":
            layer.input_scale = None
        elif self.quant_config.activation_scheme == "static":
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)
        else:
            raise ValueError(
                f"Unknown scheme {self.quant_config.activation_scheme}")

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Quantize ``x`` to fp8 and multiply it with the layer's fp8 weight.

        The CUTLASS kernel is used when it is supported and there is no bias;
        otherwise torch._scaled_mm is used on a padded input.

        Args:
            layer: Layer holding ``weight``, ``weight_scale``, ``input_scale``.
            x: Input activations.
            bias: Optional bias, only applied on the torch._scaled_mm path.

        Returns:
            The output in ``x.dtype``, narrowed along dim 0 to ``x.shape[0]``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.

        if bias is None and self.cutlass_fp8_supported:
            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)

            # Fused GEMM_DQ
            output = ops.cutlass_scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
            )

        else:
            qinput, x_scale = ops.scaled_fp8_quant(x,
                                                   layer.input_scale,
                                                   batch_dim_padding=17)

            # Fused GEMM_DQ -- note we padded the input above because
            # torch._scaled_mm is more performant for matrices with
            # batch dimension > 16. Note that this could change
            # in the future.
            output, _ = torch._scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                bias=bias,
            )

        # Drop any padding rows added for torch._scaled_mm.
        return torch.narrow(output, 0, 0, x.shape[0])
271
272


273
274
275
276
277
278
279
280
class Fp8KVCacheMethod(QuantizeMethodBase):
    """Loads the kv-cache scaling factor stored in FP8 checkpoints."""

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka kv_scale) for an attention layer.

        The scale defaults to 1.0 and is overwritten during weight loading
        when the checkpoint provides a kv_scale.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        default_scale = torch.tensor(1.0)
        layer.kv_scale = Parameter(default_scale, requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        """Always raise: this method only carries a scale, it has no forward.

        Raises:
            RuntimeError: On every call.
        """
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        """Move the loaded kv_scale into ``layer._kv_scale`` as a float.

        With an "auto" kv-cache dtype the loaded scale is ignored. In every
        non-raising path the temporary ``kv_scale`` parameter is removed.

        Raises:
            ValueError: If the loaded scale is not a single scalar.
        """
        cache_dtype = layer.kv_cache_dtype
        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
        # regardless whether the kv-scale is available in the checkpoint.
        if cache_dtype != "auto":
            scale_value = layer.kv_scale.to("cpu").tolist()
            # tolist() yields a plain float only for a 0-dim tensor.
            if not isinstance(scale_value, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")
            layer._kv_scale = scale_value

            scale_is_default = layer._kv_scale == 1.0
            if scale_is_default and "e5m2" not in cache_dtype:
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
                    "cause accuracy issues. Please make sure kv-cache scaling "
                    "factor is available in the fp8 checkpoint.")
        del layer.kv_scale


311
def per_tensor_quantize(tensor: torch.Tensor,
312
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
313
314
315
316
317
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


318
319
320
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
321
322
323
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight