fp8.py 14.9 KB
Newer Older
1
2
import torch

3
from dataclasses import dataclass
4
from typing import Optional, Tuple, Union, List
5
from loguru import logger
6
7

from text_generation_server.utils.import_utils import SYSTEM
8
9
10
11
12
13
14
from text_generation_server.utils.weights import (
    Weight,
    WeightsLoader,
    UnquantizedWeight,
    Weights,
)
from text_generation_server.utils.log import log_master, log_once
15
16
import importlib.util

17
18
19
20

# Capability flags for the optional FBGEMM FP8 kernels. Both start out False
# and are only switched on by the import-time probe further down this module.
# FBGEMM_MM_AVAILABLE gates the FBGEMM row-wise FP8 matmul used in
# `Fp8Linear.forward`; FBGEMM_DYN_AVAILABLE gates the FBGEMM per-row
# quantization used in `fp8_quantize`.
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False

21
22
23
24
25
26
27
28
29

def is_fbgemm_gpu_available():
    """
    Report whether the `fbgemm_gpu.experimental.gen_ai` module can be located.

    Returns `True` when an import spec is found for it and `False` otherwise,
    including when the lookup itself raises `ModuleNotFoundError`.
    """
    module_name = "fbgemm_gpu.experimental.gen_ai"
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # Raised when a parent package of the dotted name cannot be imported.
        return False
    return spec is not None


# Import-time probe: enable the FBGEMM code paths only when the package is
# installed and, on CUDA, the device's compute capability allows it.
if is_fbgemm_gpu_available():
    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        # The FBGEMM matmul is enabled only for compute capability major 9;
        # the dynamic per-row quantization for major 8 and above.
        FBGEMM_MM_AVAILABLE = major == 9
        FBGEMM_DYN_AVAILABLE = major >= 8
else:
    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
36
37
38
39
40
41
42
43


def get_fp8_linear() -> torch.nn.Module:
    """
    Return an FP8 linear `Module` that is compatible with the current system.

    The value returned is the layer *class*, not an instance: callers build
    the layer through its `from_unquant` / `from_fp8` constructors.
    """

    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        if major == 8:
            # Imported here rather than at module top, so the Marlin layer is
            # only loaded on this path.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear
Nicolas Patry's avatar
Nicolas Patry committed
52
53


54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bits pattern 10000000(-128) represents zero in e4m3fn
    # but NaN in e4m3fnuz. So here we set it to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale


78
def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
    scalar: bool = False,
):
    """
    Quantize `weight` to `qdtype` and return `(qweight, scale)`.

    This function returns a reciprocal of the scale, so that a tensor can be unscaled
    by multiplying it with the returned scale. If a scale is given through the `scale`
    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
    be used without modification).

    Args:
        weight: Tensor to quantize.
        scale: Optional pre-computed reciprocal scale. When given, it is always
            used as-is and no scale is derived from `weight`.
        scale_upper_bound: Optional bound, forwarded as `scale_ub` to the FBGEMM
            per-row kernel or used as the upper clamp on the absolute maximum
            in the fallback path.
        qdtype: Target FP8 dtype.
        scalar: When true, never use the FBGEMM per-row kernel; a single scale
            is computed from the global absolute maximum instead.

    Returns:
        A `(qweight, scale)` tuple. On ROCm both are additionally passed through
        `normalize_e4m3fn_to_e4m3fnuz`.
    """
    # `scale is None` rather than `not scale`: truthiness of a tensor raises for
    # more than one element and would treat a zero-valued scale as "missing".
    if FBGEMM_DYN_AVAILABLE and not scalar and scale is None:
        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
        )
        return qweight, scale

    finfo = torch.finfo(qdtype)

    if scale is None:
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        # scale and clamp the tensor to bring it to
        # the representative range of float8 data type
        # (as default cast is unsaturated)
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        # Use reciprocal to avoid more expensive division.
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)

    # Return both float8 data and the inverse scale (as float),
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

    return qweight, scale


122
123
124
125
126
127
128
129
130
131
132
133
class HybridFP8UnquantLoader(WeightsLoader):
    """
    Weight loader that loads FP8 and unquantized Torch tensors.

    Every `get_*` method follows the same three-way outcome:

    * the stored weight is already `torch.float8_e4m3fn`: it is returned as an
      `Fp8Weight` together with its `weight_scale` and, when the checkpoint
      has one, its `input_scale`;
    * otherwise, when `to_fp8` is set: the weight is returned as an
      `Fp8Weight` without scales;
    * otherwise: the weight is returned as an `UnquantizedWeight`.
    """

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
        # Forwarded to every `Fp8Weight` built from a pre-quantized checkpoint.
        self.activation_scale_ub = activation_scale_ub
        # Whether non-FP8 checkpoints should still be wrapped as `Fp8Weight`.
        self.to_fp8 = to_fp8

    def get_weights(self, weights: "Weights", prefix: str):
        """Load the full (unsharded) `{prefix}.weight` tensor."""
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """Load a packed `{prefix}.weight` tensor, sharded along dim 0."""
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if scale.numel() > 1:
                # A multi-element scale is reloaded packed and sharded the same
                # way as the weight; a single-element one is used as loaded.
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                )
                if input_scale.numel() > 1:
                    input_scale = weights.get_packed_sharded(
                        f"{prefix}.input_scale",
                        dim=0,
                        block_sizes=block_sizes,
                        to_dtype=False,
                    )
                # Collapse to a single input scale: the largest one.
                input_scale = input_scale.reshape(-1).max()

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """Load one dim-0 shard per prefix and concatenate them along `dim`."""
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        # Per-prefix shapes are kept so each scale can be expanded to the row
        # count of its own shard before concatenation.
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            # Either every prefix carries an input scale or none does.
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
        """Load `{prefix}.weight` sharded along dim 1."""
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)


276
277
278
@dataclass
class Fp8Weight(Weight):
    """
    A weight destined for an FP8 linear layer.

    When `weight_scale` is `None` the weight is treated as not yet quantized
    and `get_linear` builds the layer with `from_unquant`; otherwise the
    weight and its scales are handed to `from_fp8` unchanged.
    """

    weight: torch.Tensor
    # Dtype forwarded to the linear layer's constructors.
    dtype: torch.dtype
    # Scale of an already-quantized weight; None means "quantize on load".
    weight_scale: Optional[torch.Tensor] = None
    # Optional activation scale read from the checkpoint.
    input_scale: Optional[torch.Tensor] = None
    # Forwarded to the linear layer as `scale_upper_bound`.
    activation_scale_ub: Optional[float] = None

    def get_linear(self, bias: torch.Tensor):
        """Build the FP8 linear layer selected by `get_fp8_linear` for this weight."""
        if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
        # This is not checked by the fbgemm kernels, but they require contiguous
        # memory. Can be non-contiguous when we e.g. expand from scalars.
        self.weight_scale = self.weight_scale.contiguous()
        return get_fp8_linear().from_fp8(
            weight=self.weight,
            scale=self.weight_scale,
            dtype=self.dtype,
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
        )
298
299


Nicolas Patry's avatar
Nicolas Patry committed
300
class Fp8Linear(torch.nn.Module):
    """
    Linear layer holding an FP8-quantized weight and its reciprocal scale.

    `forward` quantizes the input on the fly and multiplies it with the stored
    weight, using the FBGEMM row-wise kernel when `FBGEMM_MM_AVAILABLE` is set
    and `torch._scaled_mm` otherwise.
    """

    # One shared ones(1) tensor per device, see `get_shared_device_identity`.
    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
    ) -> None:
        """
        Args:
            qweight: Quantized weight tensor.
            scale: Scale of `qweight`; stored as float32.
            dtype: Output dtype of `forward`.
            bias: Optional bias added to the output.
            input_scale: Optional static activation scale; stored as float32.
            scale_upper_bound: Optional bound forwarded to `fp8_quantize` when
                the input is quantized.
        """
        super().__init__()
        if FBGEMM_MM_AVAILABLE:
            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=qweight, weight_scale=scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = input_scale.float() if input_scale is not None else None

        if FBGEMM_MM_AVAILABLE:
            # On this path the bound is kept as a one-element float32 tensor
            # on the weight's device rather than as a Python float.
            self.scale_upper_bound = (
                torch.tensor(
                    [scale_upper_bound], dtype=torch.float32, device=qweight.device
                )
                if scale_upper_bound is not None
                else None
            )
        else:
            self.scale_upper_bound = scale_upper_bound

        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        """Quantize an unquantized `weight` with `fp8_quantize` and wrap it."""
        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
        return cls(
            qweight=qweight,
            scale=scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
        """
        Wrap an already-quantized `weight`.

        Recognized keyword arguments are `input_scale` and `scale_upper_bound`;
        both default to `None` and any other keyword is ignored.
        """
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)

        if FBGEMM_DYN_AVAILABLE:
            # fbgemm needs float32 scales.
            scale = scale.float()
        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        """Return the cached `torch.ones(1)` tensor for `device`, creating it once."""
        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
        if device not in cls._device_identity_cache:
            cls._device_identity_cache[device] = torch.ones(1, device=device)
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Quantize `input` to FP8 and multiply it with the stored weight."""
        if FBGEMM_MM_AVAILABLE:
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound
            )

            y = torch.ops.fbgemm.f8f8bf16_rowwise(
                qinput,
                self.qweight,
                scale,
                self.scale,
                use_fast_accum=True,
                bias=self.bias,
            )
            return y.to(self.dtype)

        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

        per_tensor_weights = self.scale.numel() == 1
        per_tensor_activations = scale.numel() == 1

        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                out_dtype=self.dtype,
                scale_a=scale,
                scale_b=self.scale,
                bias=self.bias,
            )

            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]
        else:
            # Only reachable on ROCm with a non-scalar weight or activation
            # scale: run the matmul with identity scales in float32 and apply
            # the real scales (and bias) afterwards.
            device_identity = self.get_shared_device_identity(self.qweight.device)

            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=device_identity,
                scale_b=device_identity,
                out_dtype=torch.float32,
            )
            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]

            output = output * scale * self.scale.t()
            if self.bias is not None:
                output = output + self.bias

            output = output.to(dtype=self.dtype)

        return output
442
443
444
445
446
447
448


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    """
    Load the scale tensor stored under `prefix`, expanded to `shape[0]` entries.

    The tensor is first read whole; if it holds more than one element it is
    read again as a dim-0 shard. The result is flattened and then expanded
    (not copied) to length `shape[0]`.
    """
    loaded = weights.get_tensor(prefix, to_dtype=False)
    has_many_entries = loaded.numel() > 1
    if has_many_entries:
        loaded = weights.get_sharded(prefix, dim=0, to_dtype=False)
    flat = loaded.reshape(-1)
    row_count = shape[0]
    return flat.expand(row_count)