# fp8.py
import torch

from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
    Weight,
    WeightsLoader,
    UnquantizedWeight,
    Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util


FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False


def is_fbgemm_gpu_available():
    try:
        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
    except ModuleNotFoundError:
        return False


if is_fbgemm_gpu_available():
    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        FBGEMM_MM_AVAILABLE = major == 9
        FBGEMM_DYN_AVAILABLE = major >= 8
else:
    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")


def get_fp8_linear() -> torch.nn.Module:
    """
    Return an FP8 linear `Module` that is compatible with the current system.
    """

    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        if major == 8:
            # Compute capability 8.x: fall back to the Marlin-based FP8 kernel
            # instead of torch._scaled_mm.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear
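
# Illustrative sketch (not used elsewhere in this module): callers consume the class
# returned by `get_fp8_linear` through its alternate constructors rather than by
# instantiating it directly; both `Fp8Linear` and `GPTQMarlinFP8Linear` expose
# `from_unquant`/`from_fp8` (see `Fp8Weight.get_linear` below).
def _demo_get_fp8_linear(weight: torch.Tensor) -> torch.nn.Module:
    linear_cls = get_fp8_linear()
    # Arguments are (weight, bias, dtype); the weight is quantized on the fly.
    return linear_cls.from_unquant(weight, None, torch.float16)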


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 10000000 (-128) represents negative zero in e4m3fn
    # but NaN in e4m3fnuz, so remap it to (positive) zero.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
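
# Illustrative sketch (not used elsewhere in this module): the same bit pattern decodes
# to half the value in e4m3fnuz (exponent bias 8) as in e4m3fn (exponent bias 7), so
# doubling the scale keeps the dequantized result unchanged.
def _demo_e4m3fnuz_halving():
    w_fn = torch.tensor([1.0, -0.5]).to(torch.float8_e4m3fn)
    scale_fn = torch.tensor(2.0)
    w_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(w_fn, scale_fn)
    # e.g. 1.0 reinterprets as 0.5, and 0.5 * 4.0 == 1.0 * 2.0.
    assert torch.equal(
        w_fnuz.to(torch.float32) * scale_fnuz, w_fn.to(torch.float32) * scale_fn
    )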


def fp8_quantize(
    weight, scale=None, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
    if FBGEMM_DYN_AVAILABLE and not scalar:
        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
        )
        return qweight, scale

    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)

    if scale is None:
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)

    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float32),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

    return qweight, scale
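
# Illustrative sketch (not used elsewhere in this module): `fp8_quantize` returns the
# quantized tensor together with the *inverse* scale, so dequantization is a plain
# multiply. Uses the non-FBGEMM scalar path; the tolerances are arbitrary, fp8 is lossy.
def _demo_fp8_quantize_roundtrip():
    weight = torch.randn(16, 32, dtype=torch.float16)
    qweight, inv_scale = fp8_quantize(weight, scalar=True)
    dequant = qweight.to(torch.float32) * inv_scale
    assert torch.allclose(weight.to(torch.float32), dequant, rtol=0.08, atol=1e-2)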


class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if scale.numel() > 1:
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                )
                if input_scale.numel() > 1:
                    input_scale = weights.get_packed_sharded(
                        f"{prefix}.input_scale",
                        dim=0,
                        block_sizes=block_sizes,
                        to_dtype=False,
                    )
                input_scale = input_scale.reshape(-1).max()

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)
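
# Usage sketch (hypothetical wiring; the prefix and activation bound below are made up):
# the loader is normally attached to a `Weights` instance, which dispatches to it, but
# its methods can also be called directly against a `Weights` object.
def _demo_hybrid_loader(weights: Weights) -> torch.nn.Module:
    loader = HybridFP8UnquantLoader(activation_scale_ub=1200.0, to_fp8=False)
    # Yields an Fp8Weight for float8_e4m3fn checkpoints, an UnquantizedWeight otherwise.
    w = loader.get_weights_row(weights, "model.layers.0.mlp.down_proj")
    return w.get_linear(None)  # no bias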


@dataclass
class Fp8Weight(Weight):
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
        # This is not checked by the fbgemm kernels, but they require contiguous
        # memory. Can be non-contiguous when we e.g. expand from scalars.
        self.weight_scale = self.weight_scale.contiguous()
        return get_fp8_linear().from_fp8(
            weight=self.weight,
            scale=self.weight_scale,
            dtype=self.dtype,
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
        )


class Fp8Linear(torch.nn.Module):
    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
    ) -> None:
        super().__init__()
        if FBGEMM_MM_AVAILABLE:
            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=qweight, weight_scale=scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = (
            input_scale.float().reciprocal() if input_scale is not None else None
        )

        if FBGEMM_MM_AVAILABLE:
            self.scale_upper_bound = (
                torch.tensor(
                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
                )
                if scale_upper_bound is not None
                else None
            )
        else:
            self.scale_upper_bound = scale_upper_bound

        self.bias = bias if bias is not None else None

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
        return cls(
            qweight=qweight,
            scale=scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)

        if FBGEMM_DYN_AVAILABLE:
            # fbgemm needs float32 scales.
            scale = scale.float()
        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
        if device not in cls._device_identity_cache:
            cls._device_identity_cache[device] = torch.ones(1, device=device)
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if FBGEMM_MM_AVAILABLE:
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound
            )

            y = torch.ops.fbgemm.f8f8bf16_rowwise(
                qinput,
                self.qweight,
                scale,
                self.scale,
                use_fast_accum=True,
                bias=self.bias,
            )
            return y.to(self.dtype)

        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

        per_tensor_weights = self.scale.numel() == 1
        per_tensor_activations = scale.numel() == 1

        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                out_dtype=self.dtype,
                scale_a=scale,
                scale_b=self.scale,
                bias=self.bias,
            )

            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]
        else:
            device_identity = None
            if SYSTEM == "rocm":
                device_identity = self.get_shared_device_identity(self.qweight.device)

            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=device_identity,
                scale_b=device_identity,
                out_dtype=torch.float32,
            )
            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]

            output = output * scale * self.scale.t()
            if self.bias is not None:
                output = output + self.bias

            output = output.to(dtype=self.dtype)

        return output
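
# Usage sketch (illustrative; assumes a GPU and torch build where float8_e4m3fn matmul
# via torch._scaled_mm or the fbgemm kernels is available; shapes are arbitrary
# multiples of 16):
def _demo_fp8_linear() -> torch.Tensor:
    weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    linear = Fp8Linear.from_unquant(weight, bias=None, dtype=torch.float16)
    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    return linear(x)  # (16, 4096), float16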


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    scale = weights.get_tensor(prefix, to_dtype=False)
    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
    return scale.reshape(-1).expand(shape[0])
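

# Illustrative note (not used elsewhere in this module): a per-tensor scale stored as a
# single scalar is broadcast to one entry per output row by `expand`, which produces a
# zero-copy, non-contiguous view; this is why `Fp8Weight.get_linear` calls
# `.contiguous()` before handing the scale to the fbgemm kernels.
def _demo_scale_expand() -> None:
    scale = torch.tensor(0.02).reshape(-1).expand(4096)
    assert scale.shape == (4096,)
    assert not scale.is_contiguous()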