fp8.py 13.3 KB
Newer Older
1
from typing import Any, Dict, List, Optional, Tuple, Union
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
from vllm import _custom_ops as ops
8
from vllm.logger import init_logger
9
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
10
from vllm.model_executor.layers.quantization.base_config import (
11
    QuantizationConfig, QuantizeMethodBase)
12
from vllm.model_executor.utils import set_weight_attrs
13
from vllm.utils import print_warning_once
14

15
16
17
18
# Module-level logger for this quantization backend.
logger = init_logger(__name__)

# Accepted values for Fp8Config's ``activation_scheme``:
#   "static":  input activation scales are loaded from the checkpoint.
#   "dynamic": input activation scales are computed from each input at runtime.
ACTIVATION_SCHEMES = ["static", "dynamic"]

19

20
21
22
def cutlass_fp8_supported(
        capability: Optional[Tuple[int, int]] = None,
        cuda_version: Optional[str] = None) -> bool:
    """Return whether CUTLASS FP8 GEMM kernels can be used.

    CUTLASS FP8 kernels need at least
      CUDA 12.0 on SM90 systems (Hopper)
      CUDA 12.4 on SM89 systems (Lovelace)

    Args:
        capability: ``(major, minor)`` GPU compute capability to check.
            Defaults to ``torch.cuda.get_device_capability()``.
        cuda_version: CUDA toolkit version string such as ``"12.4"``.
            Defaults to ``torch.version.cuda``.

    Returns:
        True when both the GPU architecture and the CUDA version meet the
        minimums above; False otherwise, including when torch was built
        without CUDA (``torch.version.cuda`` is None).
    """
    if cuda_version is None:
        cuda_version = torch.version.cuda
    if cuda_version is None:
        # No CUDA toolkit in this torch build, so no CUTLASS kernels.
        return False
    if capability is None:
        capability = torch.cuda.get_device_capability()

    # Compare as tuples so that e.g. a minor version >= 10 cannot collide
    # with the next major version.
    version_parts = cuda_version.split(".")
    cuda = (int(version_parts[0]),
            int(version_parts[1]) if len(version_parts) > 1 else 0)
    sm = tuple(capability[:2])

    # The thresholds are inclusive: "at least" CUDA 12.0 / 12.4.
    if sm >= (9, 0):
        return cuda >= (12, 0)
    if sm >= (8, 9):
        return cuda >= (12, 4)
    return False


38
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights together with their scales.
        activation_scheme: How input activation scales are obtained; must be
            one of ``ACTIVATION_SCHEMES`` ("static" or "dynamic").

    Raises:
        ValueError: if ``activation_scheme`` is not a supported scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        # The warning is emitted before the scheme is validated.
        if self.is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme in ACTIVATION_SCHEMES:
            self.activation_scheme = activation_scheme
        else:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by FP8 layers."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability (89, i.e. SM89)."""
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """FP8 needs no extra config files beyond the model config."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint quantization-config dict.

        Any ``quant_method`` value containing ``"fp8"`` marks the checkpoint
        as FP8-serialized. ``activation_scheme`` is passed through as-is.
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=fp8_serialized,
                   activation_scheme=scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        """Return the FP8 quantize method for ``layer``, or None.

        Linear layers get ``Fp8LinearMethod``; attention layers get
        ``Fp8KVCacheMethod``; every other layer is left unquantized.
        """
        # Imported here rather than at module level to avoid a circular
        # import.
        from vllm.attention.layer import Attention

        if isinstance(layer, LinearBase):
            method = Fp8LinearMethod(self)
        elif isinstance(layer, Attention):
            method = Fp8KVCacheMethod(self)
        else:
            method = None
        return method

    def get_scaled_act_names(self) -> List[str]:
        """FP8 scales no activation functions."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints, which are
    quantized to FP8 with dynamic activation scaling. The weight scaling
    factor is computed after the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    # The CUTLASS GEMM path is temporarily disabled because of an illegal
    # memory access. Even when enabled, it is only taken for layers without
    # bias on hardware where cutlass_fp8_supported() is True.
    _use_cutlass = False

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def _create_scale_param(
        self,
        scale_name: str,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        **extra_weight_attrs,
    ) -> None:
        """Register a float32 scale with one entry per output partition.

        The parameter is tagged with ``fp8_scales_shard_indexer`` so that
        weight loading can route each shard's scale to its own slot.
        """
        scale = Parameter(torch.empty(len(output_partition_sizes),
                                      dtype=torch.float32),
                          requires_grad=False)
        layer.register_parameter(scale_name, scale)
        set_weight_attrs(
            scale, {
                **extra_weight_attrs,
                "fp8_scales_shard_indexer":
                self.scales_shard_indexer,
            })

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) params.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)``. Its dtype is float8_e4m3fn when the
        checkpoint is FP8-serialized and ``params_dtype`` otherwise.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Consumed by process_weights_after_loading.
        layer.process_after_load = True
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            self._create_scale_param(
                scale_name="weight_scale",
                layer=layer,
                output_partition_sizes=output_partition_sizes,
                **extra_weight_attrs)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                self._create_scale_param(
                    scale_name="input_scale",
                    layer=layer,
                    output_partition_sizes=output_partition_sizes,
                    **extra_weight_attrs)

    def scales_shard_indexer(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the slot of a per-partition scale ``param`` for a shard.

        Args:
            param: The per-partition scale tensor.
            loaded_weight: The scale value read from the checkpoint.
            shard_id: An integer partition index, or one of "q", "k", "v"
                (mapped to 0, 1, 2).

        Returns:
            ``(param[index], loaded_weight)``.

        Raises:
            ValueError: for an unknown string id, or an id that is neither
                int nor str.
        """
        qkv_idxs = {"q": 0, "k": 1, "v": 2}

        if isinstance(shard_id, str):
            if shard_id not in qkv_idxs:
                raise ValueError(f"Unknown shard_id: {shard_id}")
            shard_id = qkv_idxs[shard_id]
        elif not isinstance(shard_id, int):
            # Previously the ValueError was constructed but never raised.
            raise ValueError(
                f"Shard id must be int or str but got {type(shard_id)}")

        return param[shard_id], loaded_weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize weights and scales into the form ``apply`` expects.

        The weight ends up transposed and in FP8 with a single scalar
        ``weight_scale``. ``input_scale`` is None for dynamic activation
        scaling, or a scalar for static scaling.

        Raises:
            ValueError: if static input scales differ across the logical
                weights, or the activation scheme is unknown.
        """
        if not getattr(layer, "process_after_load", False):
            return

        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.logical_widths = None
            # No static input scale exists, so activations are dynamic.
            layer.input_scale = None
            return

        # Checkpoint is fp8: requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.
        # WEIGHT_SCALE / WEIGHT
        #   Loop over logical weights, requantizing with the shared max scale.
        max_w_scale = layer.weight_scale.max()
        start = 0
        for idx, logical_width in enumerate(layer.logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
                                              layer.weight_scale[idx])
            layer.weight[start:end, :] = per_tensor_quantize(
                weight_dq, max_w_scale)
            start = end
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # WEIGHT
        #   Transpose weight for passing to torch._scaled_mm
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

        # INPUT ACTIVATION SCALE
        #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
        #   Static:  set to max of the input_scales (since they are equal).
        if self.quant_config.activation_scheme == "dynamic":
            layer.input_scale = None
        elif self.quant_config.activation_scheme == "static":
            if not all_close_1d(layer.input_scale):
                raise ValueError(
                    "All the input_scales for the logical weights of a "
                    f"layer must be equal. But got {layer.input_scale}")
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)
        else:
            raise ValueError(
                f"Unknown scheme {self.quant_config.activation_scheme}")

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the FP8 linear layer on ``x``, adding ``bias`` if given.

        The result is returned in ``x.dtype`` with ``x.shape[0]`` rows.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.
        if self._use_cutlass and bias is None and self.cutlass_fp8_supported:
            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)

            # Fused GEMM_DQ. This call takes no bias, which is why the path
            # is restricted to bias is None.
            output = ops.cutlass_scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
            )
        else:
            qinput, x_scale = ops.scaled_fp8_quant(x,
                                                   layer.input_scale,
                                                   batch_dim_padding=17)

            # Fused GEMM_DQ -- note we padded the input above because
            # torch._scaled_mm is more performant for matrices with
            # batch dimension > 16. Note that this could change
            # in the future.
            output, _ = torch._scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                bias=bias,
            )

        # Drop any padded rows so the output matches the input batch size.
        return torch.narrow(output, 0, 0, x.shape[0])
293
294


295
296
297
298
299
300
301
302
class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints.

    This method only manages a scalar ``kv_scale`` parameter on an attention
    layer; it performs no computation of its own.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka kv_scale) for an attention layer.

        The scale starts at 1.0. If a kv_scale appears in the checkpoint, it
        overwrites this value when weights are loaded.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        default_kv_scale = torch.tensor(1.0)
        layer.kv_scale = Parameter(default_kv_scale, requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        """Always raises: this method has no forward computation."""
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        """Move the loaded kv_scale into ``layer._kv_scale`` as a float.

        With an "auto" kv-cache dtype, the loaded scale is discarded and
        ``layer._kv_scale`` is left as-is (presumably 1.0, set by the layer
        itself; confirm in the Attention implementation). In every
        non-raising path the ``kv_scale`` parameter is removed.

        Raises:
            ValueError: if the loaded scale is not a single scalar.
        """
        if layer.kv_cache_dtype == "auto":
            del layer.kv_scale
            return

        # A 0-d tensor converts to a plain Python float here.
        scale_value = layer.kv_scale.to("cpu").tolist()
        if not isinstance(scale_value, float):
            raise ValueError("Only support per-tensor scaling factor "
                             "for fp8 KV cache")
        layer._kv_scale = scale_value

        is_e5m2 = "e5m2" in layer.kv_cache_dtype
        if not is_e5m2 and layer._kv_scale == 1.0:
            print_warning_once(
                "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
                "cause accuracy issues. Please make sure kv-cache scaling "
                "factor is available in the fp8 checkpoint.")
        del layer.kv_scale


333
334
335
336
337
338
def all_close_1d(x: torch.Tensor) -> bool:
    """Return True if every element of the 1-D tensor ``x`` is close to x[0].

    Closeness follows ``torch.allclose`` default tolerances, with ``x[0]`` as
    ``input`` and each element as ``other``. An empty tensor is trivially
    all-close.

    Raises:
        ValueError: if ``x`` is not 1-D. This is an explicit check rather
            than an ``assert``, so it still runs under ``python -O``.
    """
    if x.dim() != 1:
        raise ValueError(
            f"all_close_1d expects a 1-D tensor, got shape {tuple(x.shape)}")
    if x.numel() == 0:
        return True
    # One vectorized comparison instead of a Python loop over elements.
    return torch.allclose(x[0].expand_as(x), x)


def per_tensor_quantize(tensor: torch.Tensor,
339
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
340
341
342
343
344
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


345
346
347
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
348
349
350
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight