import torch

from dataclasses import dataclass
from typing import Optional, Union, List, Type
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
    Weight,
    WeightsLoader,
    UnquantizedWeight,
    Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util


FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
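# FBGEMM_MM_AVAILABLE gates the FBGEMM rowwise FP8 GEMM path (Hopper, SM 9.x).
# FBGEMM_DYN_AVAILABLE gates FBGEMM dynamic per-row FP8 quantization (SM 8.x and newer).
# Both flags are flipped below when fbgemm_gpu is installed and the GPU qualifies.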


def is_fbgemm_gpu_available():
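    """Return True when the `fbgemm_gpu` GenAI extension can be located (without importing it)."""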
    try:
        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
    except ModuleNotFoundError:
        return False


if is_fbgemm_gpu_available():
    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        FBGEMM_MM_AVAILABLE = major == 9
        FBGEMM_DYN_AVAILABLE = major >= 8
else:
    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")


def get_fp8_linear() -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` class that is compatible with the current system.
    """

    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        if major == 8:
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear


def fp8_quantize(
    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
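    """Quantize `weight` to `qdtype` (FP8 E4M3 by default).

    Uses the FBGEMM per-row quantization kernel when it is available and a
    per-tensor (scalar) scale was not requested; otherwise falls back to a
    simple per-tensor absmax scheme. Returns `(qweight, scale)`, where `scale`
    is the dequantization scale expected by the downstream matmul kernels.

    Minimal usage sketch (shapes, dtype and device are illustrative only):

        weight = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
        qweight, scale = fp8_quantize(weight, scalar=True)
    """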
    if FBGEMM_DYN_AVAILABLE and not scalar:
        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
        )
        return qweight, scale

    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
    # Scale and clamp the tensor to bring it into the representable
    # range of the float8 data type (the default cast is unsaturated).
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float32),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()
    return qweight, scale


class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8

    def get_weights(self, weights: "Weights", prefix: str):
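        """Load the full (unsharded) weight and wrap it as FP8 or unquantized."""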
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
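        """Load a column-packed weight (block-sharded along dim 0) together with its scales."""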
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if scale.numel() > 1:
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            scale = scale.reshape(-1).expand(w.shape[0])

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
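        """Load column shards from several prefixes and concatenate them into one weight."""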
        # FIXME: force `to_device=False`, as fp8 weights do not yet support torch.cat on the device.
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
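        """Load a row shard (split along dim 1) and wrap it as FP8 or unquantized."""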
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)


@dataclass
class Fp8Weight(Weight):
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None

    def get_linear(self, bias: Optional[torch.Tensor]):
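        """Build the FP8 linear layer; quantize on the fly when no scale was loaded."""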
        if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
        # The fbgemm kernels do not check this, but they require contiguous
        # memory; the scale can be non-contiguous when we e.g. expand from a scalar.
        self.weight_scale = self.weight_scale.contiguous()
        return get_fp8_linear().from_fp8(
            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
        )


class Fp8Linear(torch.nn.Module):
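    """Linear layer operating on FP8-quantized weights.

    Uses the FBGEMM rowwise FP8 GEMM kernel when available (Hopper GPUs) and
    falls back to `torch._scaled_mm` with per-tensor scales otherwise.
    """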
    def __init__(
        self,
        qweight,
        scale,
        scale_upper_bound,
        bias,
        dtype,
    ) -> None:
        super().__init__()
        if FBGEMM_MM_AVAILABLE:
            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale
        self.scale_upper_bound = (
            torch.tensor(
                [scale_upper_bound], dtype=torch.float32, device=qweight.device
            )
            if scale_upper_bound is not None
            else None
        )

        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
        return cls(
            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
        )

    @classmethod
    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
        if FBGEMM_DYN_AVAILABLE:
            # fbgemm needs float32 scales.
            scale = scale.float()
        return cls(
            qweight=weight,
            scale=scale,
            scale_upper_bound=input_scale,
            bias=bias,
            dtype=dtype,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
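        """Quantize the input to FP8 and run the FP8 matmul in the layer's dtype."""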
        if FBGEMM_MM_AVAILABLE:
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound
            )

            y = torch.ops.fbgemm.f8f8bf16_rowwise(
                qinput,
                self.qweight,
                scale,
                self.scale,
                use_fast_accum=True,
                bias=self.bias,
            )
            return y.to(self.dtype)

        qinput, scale = fp8_quantize(input, scalar=True)
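        # NOTE: unpacking two return values below matches older `torch._scaled_mm`
        # builds that returned (output, amax); recent PyTorch returns only the output.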
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        return output


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
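    """Load a weight scale that is either a scalar or a per-row (matrix) scale.

    Matrix scales are sharded along dim 0; the result is flattened and expanded
    to one value per output row of the corresponding weight shard.
    """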
    scale = weights.get_tensor(prefix, to_dtype=False)
    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
    return scale.reshape(-1).expand(shape[0])