fp8.py 9.83 KB
Newer Older
1
from typing import Any, Dict, List, Optional, Tuple, Union
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
from vllm import _custom_ops as ops
8
from vllm.logger import init_logger
9
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
10
from vllm.model_executor.layers.quantization.base_config import (
11
    QuantizationConfig)
12
from vllm.model_executor.utils import set_weight_attrs
13

14
15
16
17
logger = init_logger(__name__)

# Accepted values for Fp8Config's ``activation_scheme``:
#   "static"  - a per-layer activation scale (act_scale) is read from the
#               checkpoint.
#   "dynamic" - no act_scale is stored; the activation scale is computed
#               from each input at runtime.
ACTIVATION_SCHEMES = ["static", "dynamic"]

18

19
class Fp8Config(QuantizationConfig):
    """Quantization config for FP8 (float8_e4m3fn) linear layers.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights and their scales; False when higher-precision weights
            are loaded and quantized afterwards.
        activation_scheme: one of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").

    Raises:
        ValueError: if ``activation_scheme`` is not a supported scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")

        scheme_ok = activation_scheme in ACTIVATION_SCHEMES
        if not scheme_ok:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (89)."""
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """No extra config files are needed for FP8."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build a config from a checkpoint's quantization dict.

        The checkpoint counts as FP8-serialized when its ``quant_method``
        value contains "fp8".
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=serialized,
                   activation_scheme=scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
        """Return an FP8 linear method for linear layers, else None."""
        return Fp8LinearMethod(self) if isinstance(layer, LinearBase) else None

    def get_scaled_act_names(self) -> List[str]:
        """FP8 scales no activation functions; always empty."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints. Their
    weights are quantized to FP8 after loading with a single weight scale,
    and activations always use dynamic scaling.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def _create_scale_param(
        self,
        scale_name: str,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        **extra_weight_attrs,
    ) -> None:
        """Register an uninitialized float32 scale parameter on ``layer``.

        The parameter has one entry per logical output partition and is
        tagged with ``scales_shard_indexer`` under the attribute
        "fp8_scales_shard_indexer" (presumably consulted by the weight loader
        to pick the slot for each shard's scale -- confirm in the loader).
        """
        scale = Parameter(torch.empty(len(output_partition_sizes),
                                      dtype=torch.float32),
                          requires_grad=False)
        layer.register_parameter(scale_name, scale)
        set_weight_attrs(
            scale, {
                **extra_weight_attrs,
                "fp8_scales_shard_indexer":
                self.scales_shard_indexer,
            })

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, its scales).

        The weight has shape (sum(output_partition_sizes),
        input_size_per_partition). Its dtype is float8_e4m3fn for FP8
        checkpoints and ``params_dtype`` otherwise. Scale parameters are only
        created for FP8 checkpoints: ``weight_scale`` always, and
        ``act_scale`` only for the "static" activation scheme.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Flag consumed by process_weights_after_loading.
        layer.process_after_load = True
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            self._create_scale_param(
                scale_name="weight_scale",
                layer=layer,
                output_partition_sizes=output_partition_sizes,
                **extra_weight_attrs)

            # ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                self._create_scale_param(
                    scale_name="act_scale",
                    layer=layer,
                    output_partition_sizes=output_partition_sizes,
                    **extra_weight_attrs)

    def scales_shard_indexer(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the slot of ``param`` that receives ``loaded_weight``.

        ``shard_id`` is either an integer index or one of "q", "k", "v"
        (mapped to 0, 1, 2).

        Returns:
            ``(param[index], loaded_weight)``.

        Raises:
            ValueError: for an unknown string id, or an id that is neither
                int nor str.
        """
        qkv_idxs = {"q": 0, "k": 1, "v": 2}

        if isinstance(shard_id, str):
            if shard_id not in qkv_idxs:
                raise ValueError(f"Unknown shard_id: {shard_id}")
            shard_id = qkv_idxs[shard_id]
        elif not isinstance(shard_id, int):
            # Previously the exception was constructed but never raised,
            # so invalid ids fell through to param[shard_id].
            raise ValueError(
                f"Shard id must be int or str but got {type(shard_id)}")

        return param[shard_id], loaded_weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize ``layer``'s weight and scales for ``apply``.

        * FP16/BF16 checkpoint: quantize the whole weight to FP8 with one
          scale, clear ``logical_widths``, and set ``act_scale`` to None
          (dynamic activation scaling).
        * FP8 checkpoint: requantize every logical shard to the largest
          shard scale so the fused weight has a single per-tensor scale.
          Then collapse ``act_scale`` to a scalar (static) or None (dynamic).

        In both cases the weight ends up transposed for torch._scaled_mm.

        Raises:
            ValueError: if static act_scales differ across logical shards,
                or the activation scheme is unknown.
        """
        # Only layers prepared by create_weights carry this flag.
        if not getattr(layer, "process_after_load", False):
            return

        # If checkpoint is fp16/bf16 (not serialized fp8), quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # scale=None: the scale is computed from the tensor itself, as in
            # the dynamic-activation path of apply().
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.logical_widths = None
            layer.act_scale = None
            return

        # Checkpoint is fp8: requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.
        # WEIGHT_SCALE / WEIGHT
        #   weight_scale is not modified inside the loop, so its max is
        #   loop-invariant and computed once.
        max_w_scale = layer.weight_scale.max()
        start = 0
        for idx, logical_width in enumerate(layer.logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
                                              layer.weight_scale[idx])
            layer.weight[start:end, :] = per_tensor_quantize(
                weight_dq, max_w_scale)
            start = end
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # WEIGHT
        #   Transpose weight for passing to torch._scaled_mm
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)

        # ACT_SCALE
        #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
        #   Static:  set to max of the act_scales (since they are equal).
        if self.quant_config.activation_scheme == "dynamic":
            layer.act_scale = None
        elif self.quant_config.activation_scheme == "static":
            if not all_close_1d(layer.act_scale):
                raise ValueError(
                    "All the act_scales for the logical weights of a layer "
                    f"must be equal. But got {layer.act_scale}")
            layer.act_scale = Parameter(layer.act_scale.max(),
                                        requires_grad=False)
        else:
            raise ValueError(
                f"Unknown scheme {self.quant_config.activation_scheme}")

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` in FP8 via torch._scaled_mm.

        ``layer.weight`` is expected to be the transposed FP8 weight produced
        by process_weights_after_loading. The output has ``x``'s dtype.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.act_scale is None and x_scale computed from x.
        #   If static,  layer.act_scale is scalar and x_scale set to act_scale.
        qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)

        # Fused GEMM_DQ. torch._scaled_mm returns a pair here; only the
        # first element (the output) is used.
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scale,
            bias=bias,
        )

        return output
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265


def all_close_1d(x: torch.Tensor) -> bool:
    """Return True if every element of the 1-D tensor ``x`` is close to x[0].

    Closeness is ``torch.allclose(x[0], x[i])`` with default tolerances. An
    empty tensor is trivially all-close.

    Raises:
        ValueError: if ``x`` is not 1-D.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if x.dim() != 1:
        raise ValueError(
            f"all_close_1d expects a 1-D tensor, got shape {tuple(x.shape)}")
    # x[0] is trivially close to itself, so comparisons start at index 1.
    return all(torch.allclose(x[0], x[i]) for i in range(1, x.shape[0]))


def per_tensor_quantize(tensor: torch.Tensor,
                        inv_scale: float) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(tensor: torch.Tensor,
                          inv_scale: float) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight