import torch

from dataclasses import dataclass
from typing import Optional, Union, List, Type
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
    Weight,
    WeightsLoader,
    UnquantizedWeight,
    Weights,
)
from text_generation_server.utils.log import log_master, log_once

FBGEMM_MM_AVAILABLE = False  # rowwise FP8 GEMM kernel (enabled below on SM 9.x)
FBGEMM_DYN_AVAILABLE = False  # dynamic per-row FP8 quantization (enabled on SM >= 8.0)
try:
    import fbgemm_gpu.experimental.gen_ai

    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        FBGEMM_MM_AVAILABLE = major == 9
        FBGEMM_DYN_AVAILABLE = major >= 8
except (ImportError, ModuleNotFoundError):
    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")


def get_fp8_linear() -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` that is compatible with the current system.
    """

    if SYSTEM == "cuda":
        major, minor = torch.cuda.get_device_capability()
        if major == 8 and minor < 9:
            # Ampere (SM 8.0-8.6) has no native FP8 tensor cores, so fall back
            # to the Marlin FP8 kernel.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear
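

# Usage sketch (illustrative, not part of the original module): the return
# value is a class, not an instance; callers construct the layer through its
# classmethods, e.g.
#
#   linear_cls = get_fp8_linear()
#   layer = linear_cls.from_unquant(weight, bias, dtype=torch.bfloat16)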


def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
    if FBGEMM_DYN_AVAILABLE:
        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
            weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
        )
        return qweight, scale

    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()
    return qweight, scale
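

# Worked example of the per-tensor fallback above (a sketch, not part of the
# original file). torch.finfo(torch.float8_e4m3fn).max == 448.0, so if the
# weight's absmax is 2.0:
#
#   scale   = 448.0 / 2.0 = 224.0                    # quantization scale
#   qweight = clamp(weight * 224.0, -448, 448) cast to float8_e4m3fn
#   scale   = 1 / 224.0 ≈ 0.00446                    # returned inverse scale
#
# so `qweight.to(weight.dtype) * scale` approximately reconstructs `weight`.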


class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_packed_sharded(
                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
            )
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        w = torch.cat(w, dim=dim)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = [
                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
                for p in prefixes
            ]
            scale = torch.cat(scale, dim=0)
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)
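

# Construction sketch (hypothetical prefix, not part of the original file):
#
#   loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=True)
#   # `weights` is a Weights instance configured with this loader
#   # (constructor arguments elided).
#   w = loader.get_weights_row(weights, prefix="model.layers.0.mlp.down_proj")
#   linear = w.get_linear(bias=None)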


@dataclass
class Fp8Weight(Weight):
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
        return get_fp8_linear().from_fp8(
            self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
        )


class Fp8Linear(torch.nn.Module):
    def __init__(
        self,
        qweight,
        scale,
        scale_upper_bound,
        bias,
178
        dtype,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale
        self.scale_upper_bound = (
            torch.tensor(
                [scale_upper_bound], dtype=torch.float32, device=qweight.device
            )
            if scale_upper_bound is not None
            else None
        )

        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        qweight, scale = fp8_quantize(weight)
        return cls(
            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
        )

    @classmethod
    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
        return cls(
            qweight=weight,
            scale=scale,
            scale_upper_bound=input_scale,
            bias=bias,
            dtype=dtype,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if FBGEMM_MM_AVAILABLE:
            qinput, scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound
            )

            y = torch.ops.fbgemm.f8f8bf16_rowwise(
                qinput,
                self.qweight,
                scale,
                self.scale,
                use_fast_accum=True,
                bias=self.bias,
            )
            return y.to(self.dtype)

        qinput, scale = fp8_quantize(input)
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        return output
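

if __name__ == "__main__":
    # Minimal smoke test, a sketch only (not part of the original module).
    # Assumes a CUDA device whose compute capability supports one of the FP8
    # paths above (FBGEMM rowwise kernels, torch._scaled_mm, or the Marlin
    # fallback selected by `get_fp8_linear`).
    dtype = torch.bfloat16
    weight = torch.randn(256, 128, dtype=dtype, device="cuda")
    bias = torch.zeros(256, dtype=dtype, device="cuda")
    linear = get_fp8_linear().from_unquant(weight, bias, dtype)
    x = torch.randn(16, 128, dtype=dtype, device="cuda")
    print(linear(x).shape)  # expected: torch.Size([16, 256])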