linear.py 7.46 KB
Newer Older
1
from typing import Optional
Nicolas Patry's avatar
Nicolas Patry committed
2
3
4
5
import torch
from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM

fxmarty's avatar
fxmarty committed
6
7
8
9
10
11
if SYSTEM == "rocm":
    # FastLinearROCm's single-row fast path calls `_custom_C.LLMM1`, so the
    # compiled vllm extension must be importable on ROCm. Chain the original
    # exception so its traceback (e.g. a missing shared library) is kept.
    try:
        from vllm import _custom_C
    except Exception as e:
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") from e

Nicolas Patry's avatar
Nicolas Patry committed
12
13
14
15
16
17
18
19

class FastLinear(torch.nn.Module):
    """Frozen dense layer computing ``F.linear(input, weight, bias)``.

    Both the weight and the optional bias are stored as parameters with
    ``requires_grad=False``.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        # A missing bias is kept as a plain ``None`` attribute.
        self.bias = (
            torch.nn.Parameter(bias, requires_grad=False)
            if bias is not None
            else None
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Construct the layer from tensors fetched via ``weights.get_tensor``.

        Reads ``{prefix}.weight`` and, only when ``bias`` is truthy,
        ``{prefix}.bias``. ``config`` is accepted but not used here.
        """
        weight_tensor = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight_tensor, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Project ``input`` with the stored weight and optional bias."""
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)


fxmarty's avatar
fxmarty committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class FastLinearROCm(torch.nn.Module):
    """Unquantized linear layer with a ROCm fast path for single-row inputs.

    On ROCm, when the input holds exactly one row (every leading dimension
    is 1), the matmul may be dispatched to vllm's ``_custom_C.LLMM1`` kernel.
    All other inputs (and all non-ROCm systems) use ``F.linear``. The
    output always keeps the input's leading dimensions.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        # Frozen, matching FastLinear.
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from ``{prefix}.weight`` and, if ``bias`` is
        truthy, ``{prefix}.bias`` via ``weights.get_tensor``."""
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Return ``inp @ weight.T (+ bias)`` with leading dims preserved."""
        if SYSTEM == "rocm" and inp.dim() > 0:
            inp_shape = inp.shape
            k = inp_shape[-1]
            # numel == k (with k > 0) iff all leading dims are 1: a single row.
            if k > 0 and inp.numel() == k:
                # Flatten any rank (1-D, 2-D, 3-D, ...) to a (1, k) matrix.
                out = self._single_row_matmul(inp.reshape(1, k))
                # Restore the caller's leading dims; `view` returns a new
                # tensor, so the result must be reassigned.
                out = out.view(*inp_shape[:-1], out.shape[-1])
                if self.bias is not None:
                    out = out + self.bias
                return out
        return F.linear(inp, self.weight, self.bias)

    def _single_row_matmul(self, inp_2d: torch.Tensor) -> torch.Tensor:
        """Return ``inp_2d @ weight.T`` (no bias) for a ``(1, k)`` input."""
        weight = self.weight
        # k is the hidden (last) dim of the flattened input, not a
        # sequence/batch dim.
        m, k = weight.shape[0], inp_2d.shape[1]
        # The shape conditions and the final LLMM1 argument (8 or 4) are kept
        # unchanged; the argument's meaning is defined by the vllm kernel.
        if (k == 8192 and m in (1280, 7168)) or (k == 3584 and m == 8192):
            llmm1_arg = 8
        elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
            llmm1_arg = 4
        else:
            return F.linear(inp_2d, weight)
        # Allocate only when the kernel is used, on the input's own device.
        out = torch.empty(
            inp_2d.shape[0], m, dtype=inp_2d.dtype, device=inp_2d.device
        )
        _custom_C.LLMM1(weight, inp_2d, out, llmm1_arg)
        return out


Nicolas Patry's avatar
Nicolas Patry committed
93
94
def get_linear(weight, bias, quantize):
    """Build the linear layer implementation for a quantization scheme.

    Args:
        weight: The loaded weight. For ``exl2`` it must be an ``Exl2Weight``
            and for ``gptq``/``awq`` a ``GPTQWeight`` (checked below);
            otherwise it is passed straight to the layer constructor
            (presumably a tensor — confirm per scheme).
        bias: Optional bias, or ``None``.
        quantize: Scheme name (``"eetq"``, ``"fp8"``, ``"bitsandbytes"``,
            ``"bitsandbytes-fp4"``, ``"bitsandbytes-nf4"``, ``"exl2"``,
            ``"gptq"``, ``"awq"``), or ``None`` for no quantization.

    Returns:
        The constructed linear module.

    Raises:
        ImportError: ``eetq`` was requested but EETQ cannot be imported.
        NotImplementedError: unknown scheme, incompatible weight type,
            missing optional kernels/packages, or AWQ requested on ROCm.
    """
    if quantize is None:
        if SYSTEM == "rocm":
            linear = FastLinearROCm(weight, bias)
        else:
            linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        try:
            from text_generation_server.layers.eetq import EETQLinear

            linear = EETQLinear(weight, bias)
        except ImportError as err:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            ) from err
    elif quantize == "fp8":
        from text_generation_server.layers.fp8 import Fp8Linear

        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        try:
            from text_generation_server.layers.bnb import (
                warn_deprecate_bnb,
                Linear8bitLt,
            )
        except ImportError as err:
            raise NotImplementedError(
                "Bitsandbytes is missing install it with `pip install bitsandbytes`."
            ) from err
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            # This module imports `torch`, not `nn`; a bare `nn.Parameter`
            # raises NameError.
            linear.bias = torch.nn.Parameter(bias)
    elif quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError as err:
            raise NotImplementedError(
                "Bitsandbytes is missing install it with `pip install bitsandbytes`."
            ) from err
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4" if quantize == "bitsandbytes-fp4" else "nf4",
        )
    elif quantize == "exl2":
        from text_generation_server.layers.exl2 import Exl2Weight

        if not isinstance(weight, Exl2Weight):
            raise NotImplementedError(
                "The passed weight is not `exl2` compatible, loader needs to be updated."
            )

        from text_generation_server.layers.gptq import ExllamaQuantLinear

        linear = ExllamaQuantLinear(weight, bias)
    elif quantize == "gptq":
        from text_generation_server.layers.gptq import GPTQWeight

        if not isinstance(weight, GPTQWeight):
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            )

        if weight.use_exllama:
            try:
                from text_generation_server.layers.gptq import (
                    ExllamaQuantLinear,
                )
            except ImportError as err:
                raise NotImplementedError(
                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
                ) from err

            linear = ExllamaQuantLinear(weight, bias)
        else:
            from text_generation_server.layers.gptq.quant_linear import QuantLinear

            linear = QuantLinear(
                weight.qweight,
                weight.qzeros,
                weight.scales,
                weight.g_idx,
                bias,
                weight.bits,
                weight.groupsize,
            )
    elif quantize == "awq":
        from text_generation_server.layers.gptq import GPTQWeight

        if not isinstance(weight, GPTQWeight):
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            )
        if SYSTEM == "rocm":
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."
            )
        try:
            from text_generation_server.layers.awq.quantize.qmodule import WQLinear

            linear = WQLinear(
                w_bit=weight.bits,
                group_size=weight.groupsize,
                qweight=weight.qweight,
                qzeros=weight.qzeros,
                scales=weight.scales,
                # NOTE(review): only a boolean flag is passed, not the bias
                # values — confirm WQLinear receives the real bias elsewhere.
                bias=bias is not None,
            )
        except ImportError as err:
            raise NotImplementedError(
                "You do not seem to have awq installed, either install it (cd server &&  make install-awq), or try using GPTQ `--quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
            ) from err
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear