from dataclasses import dataclass
import os
from typing import Optional, Tuple, Type, Union, List

import torch
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
    Weight,
    WeightsLoader,
    UnquantizedWeight,
    Weights,
)
from text_generation_server.utils.log import log_once

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None


if SYSTEM == "cuda" and marlin_kernels is not None:
    major, minor = torch.cuda.get_device_capability()
    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
        major * 10 + minor
    )
else:
    CUTLASS_FP8_AVAILABLE = False


def get_fp8_linear() -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` that is compatible with the current system.
    """

    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
            #       gives better decoding throughput on L4 and L40.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
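    """Convert e4m3fn FP8 weights and scales to the e4m3fnuz representation used on ROCm."""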
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 10000000 (-128) represents zero in e4m3fn,
    # but NaN in e4m3fnuz, so we map it to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
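    # For example, the bit pattern 0x40 encodes 2.0 in e4m3fn but 1.0 in e4m3fnuz.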
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale


def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
    scalar: bool = False,
):
    """
    Quantize `weight` to FP8 and return the quantized tensor together with its scale.

    This function returns a reciprocal of the scale, so that a tensor can be unscaled
    by multiplying it with the returned scale. If a scale is given through the `scale`
    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
    be used without modification).
    """
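    # Rough sketch of the contract, assuming a floating-point tensor `w`:
    #   qweight, scale = fp8_quantize(w, scalar=True)
    #   w ≈ qweight.to(w.dtype) * scale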
    if marlin_kernels is not None:
        shape = weight.shape
        qweight, scale = marlin_kernels.scaled_fp8_quant(
            weight.reshape(-1, shape[-1]),
            dtype=qdtype,
            scale=scale,
            scale_ub=scale_upper_bound,
            # TODO: don't do this when we have to use the Torch kernel.
            use_per_token_if_dynamic=not scalar,
        )

        return qweight.reshape(shape), scale

    finfo = torch.finfo(qdtype)

    if scale is None:
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        # scale and clamp the tensor to bring it to
        # the representative range of float8 data type
        # (as default cast is unsaturated)
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        # Use reciprocal to avoid more expensive division.
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)

    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

    return qweight, scale


class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8

    def get_weights(self, weights: "Weights", prefix: str):
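        """Load the unsharded weight for `prefix`, keeping it in FP8 when the checkpoint stores it that way."""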
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
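        """Load a packed, column-sharded weight (e.g. fused QKV or gate/up projections),
        splitting any per-channel scales with the same block sizes."""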
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if scale.numel() > 1:
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                )
                if input_scale.numel() > 1:
                    input_scale = weights.get_packed_sharded(
                        f"{prefix}.input_scale",
                        dim=0,
                        block_sizes=block_sizes,
                        to_dtype=False,
                    )
                input_scale = input_scale.reshape(-1).max()

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
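        """Load the column-sharded weights for each prefix and concatenate them into a single tensor."""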
        # FIXME: force `to_device` to False, as FP8 weights do not support torch.cat on the device yet.
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
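        """Load a row-sharded weight (sharded along the input dimension) for `prefix`."""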
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)


@dataclass
class Fp8Weight(Weight):
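    """An FP8 weight (with optional weight/input scales) as returned by `HybridFP8UnquantLoader`."""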
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
        # This is not checked by the fbgemm kernels, but they require contiguous
        # memory. Can be non-contiguous when we e.g. expand from scalars.
        self.weight_scale = self.weight_scale.contiguous()
        return get_fp8_linear().from_fp8(
            weight=self.weight,
            scale=self.weight_scale,
            dtype=self.dtype,
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
        )


class Fp8Linear(torch.nn.Module):
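    """FP8 linear layer that uses cutlass w8a8 kernels when available and `torch._scaled_mm` otherwise."""

    # Per-device cache of all-ones tensors used as identity scales for the ROCm
    # fallback in `forward` (see `get_shared_device_identity`).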
    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
    ) -> None:
        super().__init__()
        if CUTLASS_FP8_AVAILABLE:
            log_once(logger.info, "Using cutlass w8a8 kernels")
        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=qweight, weight_scale=scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = input_scale.float() if input_scale is not None else None

        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
            self.scale_upper_bound = torch.tensor(
                scale_upper_bound, dtype=torch.float32, device=qweight.device
            )
        else:
            self.scale_upper_bound = scale_upper_bound

        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
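        """Quantize an unquantized weight to FP8 at load time."""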
        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
        return cls(
            qweight=qweight,
            scale=scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
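        """Construct from a weight that is already stored in FP8 (e.g. an FP8 checkpoint)."""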
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)

        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        # Input scaling factors are no longer optional in torch._scaled_mm starting
        # with PyTorch 2.5, so we allocate a dummy all-ones tensor to pass as the scale.
        if device not in cls._device_identity_cache:
            cls._device_identity_cache[device] = torch.ones(1, device=device)
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound, scalar=False
            )
            return marlin_kernels.cutlass_scaled_mm(
                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
            )

        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

        # torch._scaled_mm path. On ROCm with non-scalar scales, run the matmul with
        # identity scales and apply the real weight/activation scales to the fp32
        # output afterwards.
        per_tensor_weights = self.scale.numel() == 1
        per_tensor_activations = scale.numel() == 1

        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                out_dtype=self.dtype,
                scale_a=scale,
                scale_b=self.scale,
                bias=self.bias,
            )

            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]
        else:
            device_identity = None
            if SYSTEM == "rocm":
                device_identity = self.get_shared_device_identity(self.qweight.device)

            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=device_identity,
                scale_b=device_identity,
                out_dtype=torch.float32,
            )
            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]

            output = output * scale * self.scale.t()
            if self.bias is not None:
                output = output + self.bias

            output = output.to(dtype=self.dtype)

        return output


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
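    """Load a scale that is either a scalar or a per-channel vector and expand/shard it to `shape[0]` elements."""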
    scale = weights.get_tensor(prefix, to_dtype=False)
    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
    return scale.reshape(-1).expand(shape[0])