import math
from dataclasses import dataclass

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

tensor = torch.Tensor


class GlobalOutlierPooler(object):
    """
    This class pools outlier dimensions across layers.

    This is particularly important for small models where outlier features
    are less systematic and occur with low frequency.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)


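# Usage sketch (illustrative, not part of the library API): the pooler is a
# process-wide singleton that accumulates outlier column indices observed
# across all layers sharing the same hidden dimension.
def _example_outlier_pooling():
    pooler = GlobalOutlierPooler.get_instance()
    # hypothetical outlier indices from two layers of a model with hidden dim 1024
    pooler.add_outliers(torch.tensor([3, 77]), feature_dim=1024)
    pooler.add_outliers(torch.tensor([77, 512]), feature_dim=1024)
    # indices from a 4096-dim FFN projection are ignored (different model_dim)
    pooler.add_outliers(torch.tensor([9]), feature_dim=4096)
    return pooler.get_current_outlier_idx()  # e.g. tensor([3, 77, 512]), set order
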
class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(
                        A, dim=dims, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None


mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply

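# Usage sketch (illustrative, requires a CUDA device): mm_cublas/bmm_cublas/
# matmul_cublas all run a matmul with vector-wise int8 quantization of both
# operands. Passing e.g. precision=[16, 8, 8] would keep the forward product
# in fp16 and quantize only the backward matmuls.
def _example_matmul_cublas():
    A = torch.randn(4, 16, 32, device="cuda", dtype=torch.float16, requires_grad=True)
    B = torch.randn(32, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    out = matmul_cublas(A, B, None, "vector")  # (4, 16, 64), fp16
    out.float().sum().backward()  # gradients flow through the int8 matmuls
    return out
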

@dataclass
class MatmulLtState:
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None

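# Sketch under stated assumptions (not library API): for frozen, pre-quantized
# weights the caller fills in state.CB / state.SCB once and sets
# has_fp16_weights=False; forward() then lazily converts CB into the
# GPU-specific layout (state.CxB). backward() asserts has_fp16_weights, so
# this path is inference-only. `W_fp16` is a hypothetical (out, in) fp16
# CUDA weight tensor.
def _example_prequantized_state(W_fp16, threshold=6.0):
    state = MatmulLtState()
    state.has_fp16_weights = False
    state.is_training = False
    state.threshold = threshold
    CB, _, SCB, _, _ = F.double_quant(W_fp16)  # row-wise int8 weight + absmax scales
    state.CB, state.SCB = CB, SCB
    return state
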

class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if math.prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
            else:
                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
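        # For input columns whose absolute magnitude exceeds state.threshold,
        # the matmul is decomposed: those columns are zeroed in the int8
        # operands and multiplied separately in fp16 (step 4), then added to
        # the dequantized int8 result.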
        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert (
            A.dtype == torch.float16
        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
            A, threshold=state.threshold
        )

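        # If outlier columns were found: with fp16 weights they are zeroed in
        # the int8 operands and kept in fp16 (subA / state.subB) for step 4;
        # with pre-quantized weights the 8-bit weight state.CB is first
        # converted to the GPU-specific layout so the outlier columns can be
        # extracted from it further below.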
        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is stored in 8-bit row-major; we can transform it back to
                    # 16-bit to extract the outlier dimensions. We also need to
                    # convert it to the Turing/Ampere format.
                    state.CxB, state.SB = F.transform(
                        state.CB, to_order=formatB
                    )
                    # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                #    # generate outlier index and subB
                #    outlier_idx = torch.unique(coo_tensorA.colidx).long()
                #    state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
                #    if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                #        # do not use pool for 2nd FFN layer
                #        state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
                #    else:
                #        state.idx = outlier_idx
                #    state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

                # if state.idx is not None:
                #    # extract outliers
                #    CA[:, state.idx] = 0
                #    CAt[:, state.idx] = 0
                #    subA = A[:, state.idx]
                # else:
                #    subA = None
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
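        # During training B is re-quantized whenever it carries no .grad
        # (its value may have changed since the last optimizer step) or no
        # cached quantization exists; otherwise the cached state.CxB is reused.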
        if state.has_fp16_weights:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B)
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers

            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #    # do not use pool for 2nd FFN layer
            #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #    state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0)
                .t()
                .contiguous()
                .half()
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
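        # igemmlt consumes A in col32 layout and B in the Turing/Ampere tiled
        # layout (formatB); the int32 accumulator is dequantized with the
        # absmax statistics SCA and state.SCB.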
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB]

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
        clone_func = torch.clone
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
        req_gradA, req_gradB = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert (
            state.has_fp16_weights
        ), "Backprop only supported for fp16 weights."

        if len(grad_output.shape) == 3:
            grad_output = grad_output.view(
                -1, grad_output.shape[-1]
            ).contiguous()

        grad_A = grad_B = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
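        # With output = A @ B.t(): grad_B = grad_output.t() @ A is built from
        # the column-wise quantized tensors (Cgradt, CAt), while grad_A =
        # grad_output @ B reuses the transposed weight quantization
        # (state.CBt, state.SCBt) prepared in the forward pass.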
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, "col32")
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(
                    state.CBt, to_order=formatB, transpose=True
                )
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
                ctx.grad_shape
            )

        return grad_A, grad_B, None, None


matmul = MatMul8bitLt.apply


def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, state)
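

# Usage sketch (illustrative, requires a CUDA device): LLM.int8()-style matmul
# with mixed-precision decomposition enabled by a magnitude threshold of 6.0.
# Note that B is laid out as (out_features, in_features), i.e. the op computes
# A @ B.t() like a linear layer.
def _example_matmul_lt():
    A = torch.randn(8, 512, device="cuda", dtype=torch.float16, requires_grad=True)
    B = torch.randn(256, 512, device="cuda", dtype=torch.float16, requires_grad=True)
    out = matmul(A, B, threshold=6.0)  # shape (8, 256), fp16
    out.float().sum().backward()  # int8 backward through MatMul8bitLt.backward
    return out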