fp8.py
from dataclasses import dataclass
import os
from typing import Optional, Tuple, Type, Union, List

import torch
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
    Weight,
    WeightsLoader,
    UnquantizedWeight,
    Weights,
)
from text_generation_server.utils.log import log_once

try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None


if SYSTEM == "cuda" and marlin_kernels is not None:
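    # The compute capability is passed to marlin_kernels encoded as
    # major * 10 + minor (e.g. 89 for capability 8.9).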
    major, minor = torch.cuda.get_device_capability()
    CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
        major * 10 + minor
    )
else:
    CUTLASS_FP8_AVAILABLE = False


def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` that is compatible with the current system.
    """

    if SYSTEM == "cuda":

        major, _ = torch.cuda.get_device_capability()
        # Marlin is W8A16, use it when:
        #
        # - On capability 8.x where x < 8: W8A8 FP8 GEMM is not supported.
        # - On capability 8.9: W8A8 FP8 GEMM is supported, but Marlin-FP8 is faster.
        # - On capability 9.x when force_w8a16: cutlass kernels do not support W8A16.
        if (major == 8 or (major == 9 and force_w8a16)) and os.getenv(
            "USE_CUTLASS_W8A8", "0"
        ) != "1":
            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
            #       gives better decoding throughput on L4 and L40.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
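    """
    Convert an e4m3fn weight (and its scales) to the e4m3fnuz format used on
    ROCm, remapping the bit pattern that e4m3fnuz treats as NaN to zero and
    doubling the scales so that dequantized values are unchanged.
    """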
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 10000000 (-128) represents zero in e4m3fn
    # but NaN in e4m3fnuz, so we remap it to 0 here.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bits representation, e4m3fnuz value is half of
    # the e4m3fn value, so we should double the scaling factor to
    # get the same dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale


def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
    scalar: bool = False,
):
    """
    This function returns a reciprocal of the scale, so that a tensor can be unscaled
    by multiplying it with the returned scale. If a scale is given through the `scale`
    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
    be used without modification).
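
    Illustrative example (assuming a 2D ``weight`` tensor and a scalar scale):

        qweight, scale = fp8_quantize(weight, scalar=True)
        dequantized = qweight.to(weight.dtype) * scale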
    """
    if marlin_kernels is not None:
        shape = weight.shape
        qweight, scale = marlin_kernels.scaled_fp8_quant(
            weight.reshape(-1, shape[-1]),
            dtype=qdtype,
            scale=scale,
            scale_ub=scale_upper_bound,
            # TODO: don't do this when we have to use the Torch kernel.
            use_per_token_if_dynamic=not scalar,
        )

        return qweight.reshape(shape), scale

    finfo = torch.finfo(qdtype)

    if scale is None:
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        # Scale and clamp the tensor to bring it into the representable
        # range of the float8 data type (the default cast is unsaturated).
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        # Use reciprocal to avoid more expensive division.
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)

    # Return both the float8 data and the inverse scale (as float),
    # since both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)

    return qweight, scale


class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
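        # `activation_scale_ub` is an optional upper bound applied when
        # activations are quantized dynamically; when `to_fp8` is set,
        # unquantized checkpoints are returned as `Fp8Weight`s and quantized
        # when the linear layer is built.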
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if scale.numel() > 1:
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                )
                if input_scale.numel() > 1:
                    input_scale = weights.get_packed_sharded(
                        f"{prefix}.input_scale",
                        dim=0,
                        block_sizes=block_sizes,
                        to_dtype=False,
                    )
                input_scale = input_scale.reshape(-1).max()

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                ).reshape(-1)

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)


@dataclass
class Fp8Weight(Weight):
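    """FP8 weight and its (optional) scales; `get_linear` builds the matching FP8 linear module."""
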
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None
    force_w8a16: bool = False

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
            return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
                self.weight, bias, self.dtype
            )
        # This is not checked by the fbgemm kernels, but they require contiguous
        # memory. The scale can be non-contiguous when we e.g. expand it from a scalar.
        self.weight_scale = self.weight_scale.contiguous()
        return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
            weight=self.weight,
            scale=self.weight_scale,
            dtype=self.dtype,
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
        )


class Fp8Linear(torch.nn.Module):
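    """Linear module that runs W8A8 FP8 matmuls via cutlass or torch._scaled_mm."""
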
    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
    ) -> None:
        super().__init__()
        if CUTLASS_FP8_AVAILABLE:
            log_once(logger.info, "Using cutlass w8a8 kernels")
        if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=qweight, weight_scale=scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = input_scale.float() if input_scale is not None else None

        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
            self.scale_upper_bound = torch.tensor(
                scale_upper_bound, dtype=torch.float32, device=qweight.device
            )
        else:
            self.scale_upper_bound = scale_upper_bound

        self.bias = bias if bias is not None else None

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
        return cls(
            qweight=qweight,
            scale=scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)

        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        # Input scaling factors are no longer optional in _scaled_mm starting
        # from PyTorch 2.5, so allocate a dummy identity tensor to pass as the input scale.
        if device not in cls._device_identity_cache:
            cls._device_identity_cache[device] = torch.ones(1, device=device)
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound, scalar=False
            )
            return marlin_kernels.cutlass_scaled_mm(
                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
            )

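        # Non-cutlass path: quantize the input with a single per-tensor (scalar)
        # scale for torch._scaled_mm below.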
        qinput, scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

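        # Pass the scales to torch._scaled_mm directly, except on ROCm with
        # non-scalar (rowwise) scales: there we use identity scales and apply
        # the real weight/activation scales to the fp32 output ourselves.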
        per_tensor_weights = self.scale.numel() == 1
        per_tensor_activations = scale.numel() == 1

        if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                out_dtype=self.dtype,
                scale_a=scale,
                scale_b=self.scale,
                bias=self.bias,
            )

            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]
        else:
            device_identity = None
            if SYSTEM == "rocm":
                device_identity = self.get_shared_device_identity(self.qweight.device)

            output = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=device_identity,
                scale_b=device_identity,
                out_dtype=torch.float32,
            )
            if isinstance(output, tuple) and len(output) == 2:
                output = output[0]

            output = output * scale * self.scale.t()
            if self.bias is not None:
                output = output + self.bias

            output = output.to(dtype=self.dtype)

        return output


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
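    """
    Load a scale that is either a single scalar or a per-row tensor (sharded
    along dim 0), and expand it to one value per output row.
    """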
    scale = weights.get_tensor(prefix, to_dtype=False)
    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
    return scale.reshape(-1).expand(shape[0])