import operator
import warnings

import torch
import bitsandbytes.functional as F

from dataclasses import dataclass
from functools import reduce  # Required in Python 3

# math.prod not compatible with python < 3.8
def prod(iterable):
    return reduce(operator.mul, iterable, 1)

tensor = torch.Tensor

"""
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features 
    are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)

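# Illustrative sketch (not part of the original module): how the singleton pooler
# above could be exercised. The index values and feature_dim below are made up.
#
#   pooler = GlobalOutlierPooler.get_instance()
#   pooler.add_outliers(torch.tensor([3, 17]), feature_dim=768)
#   pooler.add_outliers(torch.tensor([3, 42]), feature_dim=768)   # duplicates are pooled away
#   pooler.add_outliers(torch.tensor([9]), feature_dim=3072)      # ignored: 2nd FFN dimension
#   idx = pooler.get_current_outlier_idx()                        # int64 tensor containing {3, 17, 42}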

class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(
                        A, dim=dims, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None


mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
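
# Usage sketch (illustrative, not part of the module): MatMul8bit quantizes both
# operands vector-wise to int8 and runs an integer GEMM. The `precision` list gates
# each product separately: precision[0] is the forward matmul, precision[1] the
# gradient w.r.t. B, precision[2] the gradient w.r.t. A; any entry != 8 falls back
# to a regular floating-point matmul for that product. Shapes below are arbitrary
# example values on a CUDA device.
#
#   A = torch.randn(16, 32, device="cuda", dtype=torch.float16, requires_grad=True)
#   B = torch.randn(32, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   out = matmul_cublas(A, B)                               # int8 forward and backward
#   out = matmul_cublas(A, B, None, "vector", [8, 16, 16])  # int8 forward, fp16 backward
#   out.sum().backward()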


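# Rough field guide for the state below (inferred from how the fields are used in
# MatMul8bitLt; not authoritative documentation): CB/CBt hold the row- and
# column-quantized int8 weight with SCB/SCBt as the matching quantization statistics,
# CxB/CxBt are the same data transformed into the GPU-specific layout named by
# formatB (with SB/SBt describing that layout), and idx/subB track the outlier
# feature columns kept in fp16 when threshold > 0.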
@dataclass
class MatmulLtState:
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    memory_efficient_backward = False
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None


class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
            else:
                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        requires_gradBias = bias is not None and bias.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # Cast A to fp16
        A_dtype = A.dtype
        if A_dtype != torch.float16:
            warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
            A = A.to(torch.float16)

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
            A, threshold=state.threshold
        )

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers

            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #    # do not use pool for 2nd FFN layer
            #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #    state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0)
                .t()
                .contiguous()
                .half()
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        # we apply the fused bias here

        if bias is None or bias.dtype == torch.float16:
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
            delayed_bias = None
        else:  # apply bias separately
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
            delayed_bias = bias

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
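            # The fp16 outlier columns of A (subA) are multiplied with their matching
            # weight slices (state.subB) and accumulated onto the int8 result in place.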
            output.addmm_(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        output = output.to(A_dtype)
        if delayed_bias is not None:
            output.add_(delayed_bias)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, req_gradBias = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state

        # Cast grad_output to fp16
        grad_output_dtype = grad_output.dtype
        grad_output = grad_output.to(torch.float16)

        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(
                -1, grad_output.shape[-1]
            ).contiguous()

        grad_A = grad_B = grad_bias = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(
                        state.CBt, to_order=formatB, transpose=True
                    )
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
            elif state.CB is not None:
                CB = state.CB.half()
                SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                CB *= SCB
                grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
            else:
                raise Exception('State must contain either CBt or CB matrix for backward')

        if req_gradBias:
            grad_bias = grad_output.sum(0)

        # Cast grad_A back to grad_output_dtype
        if req_gradA:
            grad_A = grad_A.to(grad_output_dtype)

        return grad_A, grad_B, None, grad_bias, None


def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, bias, state)
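

# Usage sketch (illustrative only; the shapes and the 6.0 threshold are example
# values, not defaults). B follows the (out_features, in_features) weight
# convention implied by the output_shape computation in MatMul8bitLt.forward.
#
#   A = torch.randn(8, 128, 4096, dtype=torch.float16, device="cuda", requires_grad=True)
#   W = torch.randn(11008, 4096, dtype=torch.float16, device="cuda", requires_grad=True)
#   out = matmul(A, W, threshold=6.0)   # feature dims with entries above 6.0 stay in fp16
#   out.float().sum().backward()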