import torch
import torch.nn as nn

import awq_ext  # AWQ CUDA kernels (compiled extension)


def make_divisible(c, divisor):
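    # Ceiling division: the number of divisor-sized chunks needed to cover c.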
    return (c + divisor - 1) // divisor


def calculate_zeros_width(in_features, group_size=128, pack_num=8):
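    # Number of packed int32 columns holding one output channel's zero points:
    # ceil(n_groups / pack_num), rounded up to a multiple that depends on group_size.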
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError

    base_width = make_divisible(in_features // group_size, pack_num)
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width


class WQLinear_GEMM(nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit weights are supported for now.")

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features

        # quick sanity check (make sure alignment)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
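
        # Packed parameter layout (pack_num = 32 // w_bit = 8 values per int32):
        #   qweight: (in_features, out_features // pack_num)               int32
        #   qzeros:  (in_features // group_size, out_features // pack_num) int32
        #   scales:  (in_features // group_size, out_features)             float16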

        self.register_buffer(
            "qweight",
            torch.zeros(
                (in_features, out_features // (32 // self.w_bit)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (in_features // self.group_size, out_features // (32 // self.w_bit)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (in_features // self.group_size, out_features),
                dtype=torch.float16,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (out_features),
                    dtype=torch.float16,
                    device=dev,
                ),
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:  # just prepare the module for loading a state_dict
            return awq_linear

        # need scales and zeros info for real quantization
        assert scales is not None and zeros is not None
        scale_zeros = zeros * scales

        awq_linear.scales = scales.clone().half()
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()

        pack_num = 32 // awq_linear.w_bit
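
        # Quantize column by column: q = round(w / scale + zero), computed here as
        # round((w + zero * scale) / scale); column idx uses the scale and zero of
        # its group, idx // group_size.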
        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[idx // group_size])
                    / awq_linear.scales[idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.to(dtype=torch.int32)
        qweight = torch.zeros(
            (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=intweight.device,
        )
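
        # Pack eight 4-bit source columns into each int32 word, in the interleaved
        # order [0, 2, 4, 6, 1, 3, 5, 7] expected by the awq_ext GEMM kernel.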

        for col in range(intweight.shape[1] // pack_num):
            if awq_linear.w_bit == 4:
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
            else:
                raise NotImplementedError("Only 4-bit weights are supported for now.")
            for i in range(pack_num):
                qweight_col = intweight[:, col * pack_num + order_map[i]]
                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
        awq_linear.qweight = qweight

        zeros = zeros.to(dtype=torch.int32)
        qzeros = torch.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=zeros.device,
        )
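
        # Pack the zero points with the same interleaved nibble order.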

        for col in range(zeros.shape[1] // pack_num):
            if awq_linear.w_bit == 4:
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
            else:
                raise NotImplementedError("Only 4-bit weights are supported for now.")
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + order_map[i]]
                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
        awq_linear.qzeros = qzeros

        return awq_linear

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)

        input_dtype = x.dtype
        if input_dtype != torch.float16:
            x = x.half()
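
        # Flatten all leading dimensions and run the 4-bit GEMM kernel. The trailing
        # 8 is a fixed constant defined by awq_ext's gemm_forward_cuda interface
        # (presumably its split-k factor).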

        out = awq_ext.gemm_forward_cuda(
            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
        )

        if input_dtype != torch.float16:
            out = out.to(dtype=input_dtype)

        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.w_bit,
                self.group_size,
            )
        )


class WQLinear_GEMV(nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit weights are supported for now.")

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.split_k_iters = 8

        # quick sanity check (make sure alignment)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        pack_num = 32 // self.w_bit
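
        # Packed parameter layout for the GEMV path, one row per output channel
        # (zeros_width = calculate_zeros_width(in_features, group_size)):
        #   qweight: (out_features, in_features // pack_num)  int32
        #   qzeros:  (out_features, zeros_width)              int32
        #   scales:  (out_features, zeros_width * pack_num)   float16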

        self.register_buffer(
            "qweight",
            torch.zeros(
                (out_features, in_features // pack_num), dtype=torch.int32, device=dev
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (out_features, calculate_zeros_width(in_features, self.group_size)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (
                    out_features,
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:  # just prepare the module for loading a state_dict
            return awq_linear

        # need scales and zeros info for real quantization
        assert scales is not None and zeros is not None
        scale_zeros = zeros * scales

        pack_num = 32 // awq_linear.w_bit
        qscales = torch.zeros(
            (
                scales.shape[0],
                calculate_zeros_width(linear.in_features, group_size) * pack_num,
            ),
            dtype=torch.float16,
            device=scales.device,
        )
        qscales[:, : scales.shape[1]] = scales
        awq_linear.scales = qscales
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()
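
        # Quantize column by column: q = round(w / scale + zero). Scales and zeros
        # are laid out per output channel here, so the group index selects a column.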

        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
                    / awq_linear.scales[:, idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.to(dtype=torch.int32)
        qweight = torch.zeros(
            (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=intweight.device,
        )
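
        # The GEMV path packs the eight source columns in their natural order
        # [0, 1, ..., 7], unlike the interleaved order used by the GEMM path.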

        for col in range(intweight.shape[1] // pack_num):
            if awq_linear.w_bit == 4:
                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
            else:
                raise NotImplementedError("Only 4-bit weights are supported for now.")
            for i in range(pack_num):
                qweight_col = intweight[:, col * pack_num + order_map[i]]
                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
        awq_linear.qweight = qweight

        zeros = zeros.to(dtype=torch.int32)
        qzeros = torch.zeros(
            (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
            dtype=torch.int32,
            device=zeros.device,
        )
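
        # qzeros is padded to calculate_zeros_width(...) packed columns, so skip
        # any position past the real number of groups while packing.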

        for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
            if awq_linear.w_bit == 4:
                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
            else:
                raise NotImplementedError("Only 4-bit weights are supported for now.")
            for i in range(pack_num):
                if col * pack_num + order_map[i] >= zeros.shape[1]:
                    continue
                qzero_col = zeros[:, col * pack_num + order_map[i]]
                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
        awq_linear.qzeros = qzeros
        return awq_linear

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)
        inputs = x.reshape(-1, x.shape[-1])

        input_dtype = inputs.dtype
        if input_dtype != torch.float16:
            inputs = inputs.half()
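
        # Batches with more than 8 rows go through the batched GEMV v2 kernel
        # (using split_k_iters); smaller batches use the single-batch GEMV kernel.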

        if inputs.shape[0] > 8:
            out = awq_ext.gemmv2_forward_cuda(
                inputs,
                self.qweight,
                self.scales,
                self.qzeros,
                self.group_size,
                self.split_k_iters,
            )
        else:
            out = awq_ext.gemv_forward_cuda(
                inputs, self.qweight, self.scales, self.qzeros, self.group_size
            )

        if input_dtype != torch.float16:
            out = out.to(dtype=input_dtype)

        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.w_bit,
                self.group_size,
            )
        )
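

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API). It assumes a CUDA
    # device and a working awq_ext build, and it uses a naive per-group min/max
    # quantization as a stand-in for AWQ's calibrated scales and zero points.
    torch.manual_seed(0)
    group_size = 128
    linear = nn.Linear(4096, 4096, bias=True, dtype=torch.float16, device="cuda")

    # Per-group asymmetric 4-bit parameters with shape (n_groups, out_features),
    # which is what WQLinear_GEMM.from_linear expects.
    w = linear.weight.data                              # (out_features, in_features)
    w_grouped = w.reshape(w.shape[0], -1, group_size)   # (out_features, n_groups, group_size)
    w_max = w_grouped.amax(dim=-1).t()                  # (n_groups, out_features)
    w_min = w_grouped.amin(dim=-1).t()
    scales = ((w_max - w_min).clamp(min=1e-4) / 15).contiguous()
    zeros = (-w_min / scales).round().clamp(0, 15).contiguous()

    qlinear = WQLinear_GEMM.from_linear(
        linear, w_bit=4, group_size=group_size, scales=scales, zeros=zeros
    )
    x = torch.randn(2, 8, linear.in_features, dtype=torch.float16, device="cuda")
    out = qlinear(x)
    ref = linear(x)
    # Shapes should match; the difference is the 4-bit quantization error.
    print(out.shape, (out - ref).abs().max().item())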