import operator
import warnings

import torch
import bitsandbytes.functional as F

from dataclasses import dataclass
from functools import reduce  # Required in Python 3

# math.prod not compatible with python < 3.8
def prod(iterable):
    return reduce(operator.mul, iterable, 1)

tensor = torch.Tensor

"""
Tim Dettmers's avatar
Tim Dettmers committed
17
18
19
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features 
    are less systematic and occur with low frequency.
20
"""
Tim Dettmers's avatar
Tim Dettmers committed
21
22
23
24
class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)


class MatMul8bit(torch.autograd.Function):
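    """8-bit matrix multiplication with vector-wise quantization (cuBLAS int8 path).

    ``precision`` is a list of three flags ``[forward, grad_B, grad_A]``: an entry of 8
    runs that computation as a quantized int8 GEMM (``F.igemm``), any other value falls
    back to a plain ``torch.matmul`` in the input dtype.
    """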
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(
                        A, dim=dims, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None


mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
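
# Illustrative usage of the cuBLAS int8 path (a minimal sketch, not part of the library
# API surface; assumes half-precision CUDA tensors):
#   A = torch.randn(16, 512, dtype=torch.half, device="cuda", requires_grad=True)
#   B = torch.randn(512, 1024, dtype=torch.half, device="cuda", requires_grad=True)
#   out = matmul_cublas(A, B)   # vector-wise quantization, precision=[8, 8, 8]
#   out.float().sum().backward()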


@dataclass
class MatmulLtState:
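    """Quantization state shared between the forward and backward passes of MatMul8bitLt.

    Roughly: CB / CBt are the row-wise and column-wise int8 quantizations of the weight,
    CxB / CxBt their GPU-specific tiled layouts (see ``formatB``), SCB / SCBt the matching
    scaling statistics, and idx / subB the outlier columns that are kept in fp16.
    """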
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    memory_efficient_backward = False
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None


class MatMul8bitLt(torch.autograd.Function):
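    """Mixed-precision int8 matmul (the LLM.int8() forward/backward).

    The forward pass quantizes A (and, for fp16 weights, B) to int8, runs the int8 GEMM via
    ``F.igemmlt``, dequantizes the result, and treats outlier feature dimensions, i.e. columns
    of A whose magnitude exceeds ``state.threshold``, in a separate fp16 matmul that is added
    back onto the output. All quantization buffers live in the supplied ``MatmulLtState``.
    """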
    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
            else:
                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        requires_gradBias = bias is not None and bias.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # Cast A to fp16
        if A.dtype != torch.float16:
            warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16")

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
            A.to(torch.float16), threshold=state.threshold
        )

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers

            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #    # do not use pool for 2nd FFN layer
            #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #    state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0)
                .t()
                .contiguous()
                .half()
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        # we apply the fused bias here

        if bias is None or bias.dtype == torch.float16:
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
            output = output.to(A.dtype)
        else:  # apply bias separately
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
            output = output.to(A.dtype).add_(bias)

        # 4. Mixed-precision decomposition matmul
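        # columns flagged as outliers were zeroed in CA/CAt above; their contribution is
        # recovered here with an fp16 matmul against the saved outlier slice of B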
        if coo_tensorA is not None and subA is not None:
            output.addmm_(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)


        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        grad_A = grad_B = grad_bias = None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0).to(ctx.dtype_bias)

        # Cast grad_output to fp16
        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(
                -1, grad_output.shape[-1]
            ).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.dtype_B)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(
                        state.CBt, to_order=formatB, transpose=True
                    )
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

            elif state.CB is not None:
                CB = state.CB.to(ctx.dtype_B)
                SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                CB *= SCB
                grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
                raise Exception('State must contain either CBt or CB matrix for backward')

        return grad_A, grad_B, None, grad_bias, None


def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
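    """Thin wrapper around ``MatMul8bitLt.apply``.

    A ``threshold`` > 0 enables the mixed-precision decomposition: activation columns whose
    magnitude exceeds the threshold bypass the int8 path and are multiplied in fp16.
    If no ``state`` is passed, a fresh ``MatmulLtState`` is created for the call.
    """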
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, bias, state)
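

# Illustrative usage of the LLM.int8() path (a minimal sketch, not part of this module;
# assumes a CUDA device, fp16 activations, and a weight laid out like a linear layer,
# i.e. with shape (out_features, in_features)):
#   A = torch.randn(8, 1024, dtype=torch.half, device="cuda", requires_grad=True)
#   W = torch.randn(4096, 1024, dtype=torch.half, device="cuda", requires_grad=True)
#   out = matmul(A, W, threshold=6.0)   # -> shape (8, 4096); outlier columns handled in fp16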