import math
from itertools import product

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb
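
# The tests in this file compare bitsandbytes 8-bit modules (Linear8bit,
# Linear8bitLt) against fp16 torch.nn.Linear baselines and against the
# pure-PyTorch reference helpers defined below.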


class MockArgs(object):
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
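    """Two-layer MLP built from bnb.nn.Linear8bitLt (dim1 -> dim2 -> dim1)."""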
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super(MLP8bit, self).__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
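    """Allow up to `count` mismatched elements; beyond that, fall back to the strict torch assert."""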
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


class LinearFunction(torch.autograd.Function):
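    """Reference implementation of an 8-bit linear layer used to cross-check bnb modules.

    Inputs and weights are quantized to int8 ("linear", "vector", or "min-max"
    scheme), multiplied with bnb.functional.igemm, and the int32 result is
    rescaled back to the original dtype.
    """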
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    def quant(x, quant_type, dim=1):
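        """Quantize `x` to int8 and return it together with its scale statistics.

        "linear" uses a single absmax scale for the whole tensor, "vector" one
        absmax per slice along `dim`, and "min-max" an asymmetric (min, scale)
        pair. For the symmetric schemes, e.g. with absmax 0.5 the value 0.25
        maps to round(0.25 / 0.5 * 127) = 64.
        """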
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    def dequant(xq, S1, S2, dtype, quant_type):
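        """Rescale a quantized matmul result back to `dtype`.

        The int32 output of multiplying two int8 tensors carries a combined
        scale of S1 * S2 / (127 * 127), applied per tensor or per vector
        depending on `quant_type`.
        """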
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
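        """Stochastic rounding: round up with probability equal to the fractional part.

        E.g. 0.3 becomes 1 with probability 0.3 and 0 otherwise, which makes
        the rounding unbiased in expectation.
        """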
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
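        """Simulate 8-bit storage: blockwise-quantize `w` with a dynamic map, dequantize, and copy the result back into `w`."""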
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
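        """Like fake_8bit_storage, but keeps the `topk` largest-magnitude values of each 256-element block in full precision (steps 1-4 below)."""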
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
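        """Quantize `x` and `weight`, multiply with igemm, and dequantize the result; falls back to a plain einsum matmul when 8-bit training is off."""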
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(
                outputq, S1, S2, x.dtype, args.quant_type
            )
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
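        """Compute gradients according to `args.use_8bit_training`.

        "forward+wgrad" computes only the weight gradient in 8-bit, "full"
        computes both the weight and input gradients in 8-bit, and any other
        value uses plain matmuls.
        """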
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=2
            )
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
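    """nn.Module wrapper around LinearFunction, configured via the mock `args` object."""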
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


def test_linear8bit():
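    """Cross-check bnb.nn.Linear8bit, the local Linear8bit reference, and
    bnb.nn.Linear8bitLt against an fp16 torch.nn.Linear on identical weights
    and inputs.
    """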
    l0 = torch.nn.Linear(32, 64).cuda().half()
    l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
    l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()

    l0.weight.data = l2.weight.data.clone()
    l0.bias.data = l2.bias.data.clone()

    l1.weight.data = l2.weight.data.clone()
    l1.bias.data = l2.bias.data.clone()

    l3.weight.data = l2.weight.data.clone()
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        t = torch.randn(16, 8, 64, device="cuda").half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()

        o0 = l0(b0)
        o1 = l1(b1)
        o2 = l2(b2)
        o3 = l3(b3)

        assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
        assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)

        loss0 = torch.nn.functional.mse_loss(o0, t)
        loss1 = torch.nn.functional.mse_loss(o1, t)
        loss2 = torch.nn.functional.mse_loss(o2, t)
        loss3 = torch.nn.functional.mse_loss(o3, t)

        loss0.backward()
        loss1.backward()
        loss2.backward()
        loss3.backward()

        assert_all_approx_close(
            l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
        )
        assert_all_approx_close(
            l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
        )
        assert_all_approx_close(
            l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
        )
        assert_all_approx_close(
            l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
        )

        err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
        err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
        err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()

        assert err1 * 0.8 < err2
        assert err2 * 0.8 < err3
        assert err3 * 0.8 < err1

        l0.weight.grad = None
        l1.weight.grad = None
        l2.weight.grad = None
        l3.weight.grad = None
        l0.bias.grad = None
        l1.bias.grad = None
        l2.bias.grad = None
        l3.bias.grad = None


threshold = [0.0, 3.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
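    """In eval mode, Linear8bitLt should cache its transformed weight: state.CxB must be populated by the second forward pass."""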
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
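    """Compare two stacked Linear8bitLt layers against fp16 nn.Linear under gradient accumulation.

    Note: with acc_steps == 10 and a 10-iteration loop, the optimizer-step
    branch is never reached, so only the per-step gradient comparison runs.
    """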
    l1 = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
    )
    l2 = torch.nn.Sequential(
        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
    )
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)


threshold = [0.0, 2.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
@pytest.mark.parametrize("memory_efficient_backward", [True, False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
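    """Layers with int8 weights must survive .cuda()/.half()/.to() in any order.

    Outputs stay fp16, state.idx is populated when threshold > 0, and with
    memory_efficient_backward the input gradient must match the fp16
    reference grad_proj @ w2 @ w1.
    """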
    l1 = (
        bnb.nn.Linear8bitLt(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
        .cuda()
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .cuda()
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = (
        MLP8bit(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
        .half()
        .to("cuda")
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = MLP8bit(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    if memory_efficient_backward:
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        assert torch.allclose(b1.grad, grad_ref)


def test_linear8bitlt_fp32_bias():
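    """A bias created in fp32 must be cast to fp16 during the first forward pass to match the input dtype."""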
    # casts model to fp16 -> int8 automatically
    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
    assert l1.weight.dtype == torch.int8
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # the forward pass casts the fp32 bias to fp16 to match the input
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
    assert l1.weight.dtype == torch.int8
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert l1.bias is None