import math
from itertools import product

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb


class MockArgs:
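    """Minimal stand-in for an argparse namespace: copies entries of `initial_data` onto attributes."""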
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
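    """Two-layer MLP built from bnb.nn.Linear8bitLt, used to exercise 8-bit forward/backward paths."""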
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
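    """Return mock training arguments: full 8-bit training with vector-wise quantization."""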
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
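    """Like torch.testing.assert_close, but tolerate up to `count` mismatched elements before failing."""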
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


class LinearFunction(torch.autograd.Function):
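    """Reference autograd function that simulates 8-bit linear layers for the tests:
    quantize the operands, matrix-multiply with bnb.functional.igemm, then dequantize."""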
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    def quant(x, quant_type, dim=1):
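        """Quantize `x` to int8 with linear (per-tensor), vector-wise, or min-max scaling; returns (int8 tensor, scaling state)."""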
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    def dequant(xq, S1, S2, dtype, quant_type):
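        """Dequantize an integer matmul result back to `dtype` using the scaling states `S1` and `S2` of both operands."""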
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
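        """Round stochastically: the magnitude is rounded up with probability equal to its fractional part."""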
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
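        """Quantize activations and weights, multiply via bnb.functional.igemm, and dequantize the
        int32 result; falls back to a float einsum when 8-bit training is switched off."""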
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(
                outputq, S1, S2, x.dtype, args.quant_type
            )
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
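        """Compute grad_input/grad_weight either through quantized igemm matmuls or in floating
        point, depending on args.use_8bit_training ("forward+wgrad", "full", or anything else)."""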
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=2
            )
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


threshold = [0.0, 3.0]
values = threshold
names = [f"threshold_{vals}" for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
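    """Run Linear8bitLt in eval mode on fp16 inputs and check that the quantized weight state (CxB) is built after the first forward pass."""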
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
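    """Train Linear8bitLt and torch.nn.Linear side by side with gradient accumulation and check that
    their gradients, and the weights at each optimizer step, stay approximately equal."""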
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l1[0].weight.data.copy_(l2[0].weight.data)
    l1[1].weight.data.copy_(l2[1].weight.data)
    l1[0].bias.data.copy_(l2[0].bias.data)
    l1[1].bias.data.copy_(l2[1].bias.data)

    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
            l1[0].bias.data.copy_(l2[0].bias.data)
            l1[1].bias.data.copy_(l2[1].bias.data)
        else:
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)


@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
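    """Linear8bitLt/MLP8bit with has_fp16_weights=False: weights must be int8 regardless of the
    cuda()/half() call order, outputs stay fp16, and with memory_efficient_backward the input
    gradient must match an fp16 reference."""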
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .cuda()
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = MLP8bit(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    if memory_efficient_backward:
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        scale = grad_ref.abs().mean()
        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005


@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
def test_linear_kbit_fp32_bias(module):
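    """k-bit layers keep an fp32 bias after .cuda(); the bias is cast to fp16 on the first forward with fp16 inputs."""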
    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # casts bias to fp32
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64, bias=False).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert l1.bias is None

modules = []
modules.append(bnb.nn.Linear8bitLt)
modules.append(bnb.nn.Linear4bit)
modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'FP4+fp32', 'FP4+fp16', 'FP4+bf16']
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("module", modules, ids=names)
def test_kbit_backprop(module):
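    """Backprop through an fp16 Linear followed by a k-bit layer and compare outputs and the fp16 layer's gradients against a full fp16 reference."""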
    b = 17
    dim1 = 37
    dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    errs1 = []
    errs2 = []
    relerrs1 = []
    relerrs2 = []
    for i in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        err1 = (out1-out2).abs().float()
        err2 = (grad1-grad2).abs().float()
        relerr1 = (err1/(out1.abs().float()+1e-9))
        relerr2 = (err2/(grad1.abs().float()+1e-9))
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
        relerrs2.append(relerr2.mean().item())

        if isinstance(kbit[1], bnb.nn.Linear8bitLt):
            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
    print('out', sum(errs1)/len(errs1))
    print('grad', sum(errs2)/len(errs2))
    print('rel out', sum(relerrs1)/len(relerrs1))
    print('rel grad', sum(relerrs2)/len(relerrs2))

def test_fp8linear():
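    """Compare a two-layer LinearFP8Mixed stack against an fp32 reference: the forward error and the weight/bias gradient errors must stay below fixed thresholds."""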

    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h*2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
    fp32b = torch.nn.Linear(h*2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()

    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
    fp8b.weight.data.copy_(fp32b.weight.data)
    fp8b.bias.data.copy_(fp32b.bias.data)

    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
    b = fp8b(torch.nn.functional.gelu(fp8(inp)))

    err = (a-b).abs().mean()

    a.mean().backward()
    b.mean().backward()

    graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002

def test_4bit_warnings():
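    """Linear4bit with compute_dtype=torch.float32 must warn about slow fp16 inference/training; two warnings are expected for the combined run."""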
    dim1 = 64

    with pytest.warns(UserWarning, match=r'inference or training'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)
    with pytest.warns(UserWarning, match=r'inference.'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    with pytest.warns(UserWarning) as record:

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    assert len(record) == 2

