import torch
import torch.nn as nn
from functools import partial

from bitsandbytes.triton.triton_utils import is_triton_available

from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize


class _switchback_global(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_global(W)

        # save tensors for backward
        ctx.save_for_backward(X, W)

        # matmul with fused dequantize and bias add
        # called "mixed" because we mix rowwise-quantized X with globally quantized W
        return int8_matmul_mixed_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape input to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        X, W = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            # rowwise quantize G, globally quantize W
            # the transpose of W is fused into its quantization because only A @ B^T
            # is supported, so we transpose once here and call .t() in the matmul
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # backward pass uses standard weight grad
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias
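
# Rough intuition for the two quantization modes used above (an illustrative
# sketch, not the exact convention of the triton kernels): rowwise keeps one
# scale per row, global keeps a single scale for the whole tensor.
#
#   x = torch.randn(4, 8)
#   row_scale = x.abs().max(dim=1, keepdim=True).values / 127   # one scale per row
#   glob_scale = x.abs().max() / 127                            # one scale per tensor
#   x_int8 = torch.clamp((x / row_scale).round(), -127, 127).to(torch.int8)
#   x_approx = x_int8 * row_scale                               # dequantize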

class _switchback_vectorrize(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        ctx.save_for_backward(X, W)
        # rowwise quantize X
        # rowwise quantize W too (equivalent to columnwise quantization of W^T,
        # which the A @ B^T kernel consumes via the .t() below)
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_rowwise(W)

        # matmul with fused dequantize and bias add
        # this kernel expects rowwise-quantized X and W
        return int8_matmul_rowwise_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        X, W = ctx.saved_tensors

        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        if ctx.needs_input_grad[0]:
            # rowwise quantize G; columnwise quantize W with a fused transpose
            # .t() is applied to the weight in the matmul because only A @ B^T is supported
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_columnwise_and_transpose(W)
            grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # backward pass uses standard weight grad
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias

class _switchback_global_mem_efficient(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))
        X_3D_sz = X_3D.size()

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        del X
        W_int8, state_W = quantize_global(W)

        # save quantized tensors for backward
        ctx.save_for_backward(X_int8, state_X, W_int8, state_W)

        # matmul with fused dequantize and bias add
        # called "mixed" because we mix rowwise-quantized X with globally quantized W
        return int8_matmul_mixed_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D_sz[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape input to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))
        G_3D_sz = G_3D.size()

        grad_X = grad_W = grad_bias = None

        X_int8, state_X, W_int8, state_W = ctx.saved_tensors
        if ctx.needs_input_grad[1]:
            real_X = dequantize_rowwise(X_int8, state_X)
            del X_int8
            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
            del real_X
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        if ctx.needs_input_grad[0]:
            G_int8, state_G = quantize_rowwise(G)
            del G
            # make W^T contiguous so the .t() below yields the layout the kernel expects
            W_int8 = W_int8.t().contiguous()
            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D_sz[:-1], -1
            )

        return grad_X, grad_W, grad_bias

class SwitchBackLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        vector_wise_quantization: bool = False,
        mem_efficient: bool = False,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)

        if not is_triton_available():
            raise ImportError(
                "Could not import triton. Please install triton to use SwitchBackLinear. "
                "Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower."
            )

        # By default, we use global quantization.
        self.vector_wise_quantization = vector_wise_quantization
        if self.vector_wise_quantization:
            self._fn = _switchback_vectorrize
            if mem_efficient:
174
                raise ValueError('mem_efficient is not supported for vector-wise quantization.')
        else:
            if mem_efficient:
                self._fn = _switchback_global_mem_efficient
            else:
                self._fn = _switchback_global

    def prepare_for_eval(self):
        # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
        # Note this is experimental and not tested thoroughly.
        # Note this needs to be explicitly called with something like
        # def cond_prepare(m):
        #     if hasattr(m, "prepare_for_eval"):
        #         m.prepare_for_eval()
        # model.apply(cond_prepare)
        print('=> preparing for eval.')
        if self.vector_wise_quantization:
            W_int8, state_W = quantize_rowwise(self.weight)
        else:
            W_int8, state_W = quantize_global(self.weight)

        self.register_buffer("W_int8", W_int8)
        self.register_buffer("state_W", state_W)

        del self.weight

    def forward(self, x):
        if self.training:
            return self._fn.apply(x, self.weight, self.bias)
        else:
            # If it hasn't been "prepared for eval", run the standard forward pass.
            if not hasattr(self, "W_int8"):
                return self._fn.apply(x, self.weight, self.bias)

            # Otherwise, use pre-computed weights.
            X = x.view(-1, x.size(-1))
            X_int8, state_X = quantize_rowwise(X)

            if self.vector_wise_quantization:
                return int8_matmul_rowwise_dequantize(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)
            else:
                return int8_matmul_mixed_dequantize(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)

SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
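
# Example usage (a minimal sketch, not part of the library; assumes triton is
# installed, a CUDA device is available, and fp16 activations are used):
#
#   linear = SwitchBackLinearGlobal(1024, 4096).cuda().half()
#   x = torch.randn(8, 128, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
#   out = linear(x)  # int8 matmul with fused dequantize and bias add
#   out.float().sum().backward()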

# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # flatten input to [N * L, D]; the flattened X is what gets saved for backward
        X = input.view(-1, input.size(-1))

        ctx.save_for_backward(X, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output_3D):
        input, weight, bias = ctx.saved_tensors

        grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

class StandardLinear(nn.Linear):

    def forward(self, x):
        return StandardLinearFunction.apply(x, self.weight, self.bias)
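
# Sanity-check sketch (illustrative, not part of the library): StandardLinear
# implements the same function as nn.Linear, so the two should agree numerically:
#
#   ref = nn.Linear(64, 32)
#   std = StandardLinear(64, 32)
#   std.load_state_dict(ref.state_dict())
#   x = torch.randn(4, 16, 64)
#   assert torch.allclose(std(x), ref(x), atol=1e-6)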