import math
from itertools import product

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb


class MockArgs:
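    """Bare attribute container that mimics a parsed-args namespace for these tests."""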
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
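    """Two-layer MLP built from bnb.nn.Linear8bitLt layers, used to exercise int8 inference and backward."""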
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
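    """Assert that a and b are close, tolerating up to `count` mismatched elements before failing hard."""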
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


class LinearFunction(torch.autograd.Function):
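    """Reference autograd function that simulates 8-bit training: int8 igemm matmuls with several quantization schemes."""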
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
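        """Simulate 8-bit quantization of x, clipping values beyond trim_value standard deviations."""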
        round_func = LinearFunction.round_stochastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    @staticmethod
    def quant(x, quant_type, dim=1):
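        """Quantize x to int8 with "linear" (tensor-wise), "vector" (per-dim), or "min-max" scaling."""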
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    @staticmethod
    def dequant(xq, S1, S2, dtype, quant_type):
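        """Dequantize an int32 igemm result using the scaling statistics S1 and S2 of its operands."""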
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    @staticmethod
    def dequant_min_max(xq, A, B, SA, SB, dtype):
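        """Dequantize an igemm result whose A operand was min-max quantized; the offset term restores A's (minA + scale) shift."""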
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    @staticmethod
    def get_8bit_linear(x, stochastic=False):
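        """Simulate 8-bit quantization of x with a single tensor-wise absmax scale."""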
        round_func = LinearFunction.round_stochastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
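        """Simulate 8-bit quantization of x with per-slice absmax scales along dim."""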
        round_func = LinearFunction.round_stochastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stochastic(x):
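        """Round to the nearest integer stochastically, with probability given by the fractional part."""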
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
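        """Round-trip w through blockwise 8-bit storage with a dynamic code map, writing the result back into w."""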
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
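        """Like fake_8bit_storage, but with a code map estimated from the quantiles of w."""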
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stochastic(w):
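        """Blockwise 8-bit storage round-trip with stochastic rounding driven by `rand`."""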
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
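        """Blockwise 8-bit round-trip that keeps the topk largest-magnitude values of each 256-element block in full precision."""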
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        # quantize/dequantize with the dynamic code map created here
        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data, code=code)
        bnb.functional.dequantize_blockwise(absmax, C, code, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
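        """Forward pass: int8 igemm of quantized x and weight when 8-bit training is on, fp16 einsum otherwise."""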
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(
                outputq, S1, S2, x.dtype, args.quant_type
            )
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
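        """Backward pass; args.use_8bit_training selects which of grad_weight/grad_input use int8 igemm."""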
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=2
            )
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )
        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
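    """Linear layer whose matmuls are routed through the quantizing LinearFunction above."""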
    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


threshold = [0.0, 3.0]
values = threshold
names = [f"threshold_{vals}" for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
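    """Smoke-test int8 inference; the quantized weight buffer CxB must be cached after the first forward."""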
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
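    """Train Linear8bitLt against torch.nn.Linear with gradient accumulation and compare grads and weights."""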
    l1 = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
    )
    l2 = torch.nn.Sequential(
        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
    )
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    # run two accumulation windows so the optimizer-step branch below executes
    for i in range(acc_steps * 2):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)


threshold = [0.0, 2.0]
values = threshold
names = [f"threshold_{vals}" for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
@pytest.mark.parametrize("memory_efficient_backward", [True, False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
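    """Check int8 weights under every cast order (.cuda().half(), .half().cuda(), .half().to("cuda")) and, optionally, the memory-efficient backward against an fp16 reference."""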
    l1 = (
        bnb.nn.Linear8bitLt(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
        .cuda()
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .cuda()
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = (
        MLP8bit(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
        .half()
        .to("cuda")
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = MLP8bit(
        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
    )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    if memory_efficient_backward:
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        scale = grad_ref.abs().mean()

        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005


def test_linear8bitlt_fp32_bias():
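    """An fp32 bias must be cast to fp16 by the first forward pass with fp16 inputs."""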
    # casts model to fp16 -> int8 automatically
    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
    assert l1.weight.dtype == torch.int8
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # the forward pass casts the fp32 bias to fp16 to match the input
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
    assert l1.weight.dtype == torch.int8
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert l1.bias is None