import operator
import torch
import bitsandbytes.functional as F

from dataclasses import dataclass
from functools import reduce  # Required in Python 3

# math.prod is not available on Python < 3.8
def prod(iterable):
    return reduce(operator.mul, iterable, 1)
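
# e.g. prod((2, 3, 4)) == 24 and prod(()) == 1; behaves like math.prod on Python >= 3.8.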

tensor = torch.Tensor

"""
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features 
    are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)
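
# Illustrative usage sketch for the pooler above (comments only, not executed);
# the outlier indices and feature_dim below are hypothetical values:
#
#   pooler = GlobalOutlierPooler.get_instance()        # process-wide singleton
#   pooler.add_outliers(torch.tensor([7, 128, 513]), feature_dim=1024)
#   idx = pooler.get_current_outlier_idx()             # int64 tensor of pooled column indices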


class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(
                        A, dim=dims, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None


mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
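
# Illustrative sketch of calling the 8-bit cuBLAS-style matmul above (comments only,
# not executed); the shapes and CUDA device are assumptions. precision=[8, 8, 8]
# quantizes the forward product and both backward products to int8; setting an entry
# to anything other than 8 makes that product fall back to a plain torch.matmul.
#
#   A = torch.randn(32, 64, 256, dtype=torch.float16, device="cuda", requires_grad=True)
#   B = torch.randn(256, 128, dtype=torch.float16, device="cuda", requires_grad=True)
#   out = matmul_cublas(A, B, None, "vector", [8, 8, 8])   # shape (32, 64, 128)
#   out.float().sum().backward()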


@dataclass
class MatmulLtState:
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    memory_efficient_backward = False
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None
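
# Illustrative sketch (comments only; the values are assumptions): a MatmulLtState is
# typically created once per weight matrix and reused across calls so that quantized
# weight buffers (CB/CxB/SCB) stay cached between forward passes.
#
#   state = MatmulLtState()
#   state.threshold = 6.0           # enable the mixed-precision outlier decomposition
#   state.has_fp16_weights = True   # B arrives in fp16 and is (re)quantized in forward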


class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
            else:
                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
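        #
        # In short: the non-outlier part of the product is computed with an int8
        # kernel on the quantized inputs, while columns of A whose absolute values
        # exceed state.threshold are kept in fp16 and handled by the separate
        # matmul in step 4, then added to the dequantized int8 result.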
        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        requires_gradBias = bias is not None and bias.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # Cast A to fp16
        A_dtype = A.dtype
        A = A.to(torch.float16)

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
            A, threshold=state.threshold
        )

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B)
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers

            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #    # do not use pool for 2nd FFN layer
            #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #    state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0)
                .t()
                .contiguous()
                .half()
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        # we apply the fused bias here
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)
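            # (the outlier columns of A were zeroed in CA before the int8 matmul,
            # so this fp16 product over just those columns restores their contribution)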

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        # Cast fp16 output back to A.dtype
        output = output.to(A_dtype)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, req_gradBias = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state

        # Cast grad_output to fp16
        grad_output_dtype = grad_output.dtype
        grad_output = grad_output.to(torch.float16)

        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(
                -1, grad_output.shape[-1]
            ).contiguous()

        grad_A = grad_B = grad_bias = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
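        # grad_B = grad_output^T @ A, computed as an int8 matmul from the quantized
        # activations saved in forward (CAt / SCAt); if outlier columns were split
        # off, their contribution is added back in fp16 below.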
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(
                        state.CBt, to_order=formatB, transpose=True
                    )
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
            elif state.CB is not None:
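                # only the row-major int8 weight CB is available (no transposed CBt),
                # so dequantize it with its per-row absmax scales (SCB / 127) and
                # fall back to a plain fp16 matmul for the input gradient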
                CB = state.CB.half()
                SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                CB *= SCB
                grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
            else:
                raise Exception('State must contain either CBt or CB matrix for backward')

        if req_gradBias:
            grad_bias = grad_output.sum(0)

        # Cast grad_A back to grad_output_dtype
        if grad_A is not None:
            grad_A = grad_A.to(grad_output_dtype)

        return grad_A, grad_B, None, grad_bias, None


def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, bias, state)
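

if __name__ == "__main__":
    # Minimal smoke-test sketch for the 8-bit matmul with mixed-precision
    # decomposition defined above. The shapes, the threshold value, and the CUDA
    # requirement are illustrative assumptions, not part of the library API.
    if torch.cuda.is_available():
        A = torch.randn(4, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
        W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda", requires_grad=True)  # [out, in]
        out = matmul(A, W, threshold=6.0)  # int8 matmul; columns of A with |values| > threshold stay in fp16
        out.float().sum().backward()
        print(out.shape, A.grad.shape, W.grad.shape)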