# test_modules.py
import math
from itertools import product

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb


class MockArgs:
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
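    # A small two-layer MLP built from Linear8bitLt, used below to exercise
    # int8 inference with different threshold / fp16-weight settings.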
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
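    # Like torch.testing.assert_close, but tolerates up to `count` mismatched
    # elements before failing; quantized ops are expected to produce a few outliers.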
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


class LinearFunction(torch.autograd.Function):
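    # Reference autograd.Function that emulates 8-bit training: inputs and
    # weights are quantized to int8, multiplied with bnb.functional.igemm,
    # and dequantized again, so results can be compared against fp matmuls.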
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    @staticmethod
    def quant(x, quant_type, dim=1):
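        # "linear": one absmax scale for the whole tensor;
        # "vector": one absmax scale per slice along `dim`;
        # "min-max": asymmetric per-slice quantization with offset and scale.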
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    @staticmethod
    def dequant(xq, S1, S2, dtype, quant_type):
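        # Invert `quant` for the int32 igemm result: rescale by both
        # quantization scales (S1, S2) and cast back to the requested dtype.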
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    @staticmethod
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    @staticmethod
    def get_8bit_linear(x, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
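        # Fake quantization: blockwise-quantize w to 8 bit with a dynamic code,
        # dequantize, and write the fp16 result back into w in place.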
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
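        # Same idea as fake_8bit_storage, but the top-k largest-magnitude values
        # in each 256-element block are kept exact and only the rest is quantized.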
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
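        # Quantize activations and weight, multiply with the int8 kernel
        # (igemm), then dequantize; falls back to a regular fp einsum when
        # 8-bit training is disabled.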
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(
                outputq, S1, S2, x.dtype, args.quant_type
            )
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
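        # Depending on args.use_8bit_training, the weight gradient (and
        # optionally the input gradient) is computed with 8-bit igemm;
        # otherwise plain fp matmuls are used.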
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=2
            )
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
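    # nn.Module wrapper around LinearFunction so the emulated 8-bit matmul can
    # be used like a normal linear layer.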
    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


threshold = [0.0, 3.0]
values = threshold
names = [f"threshold_{vals}" for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
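    # In eval mode, Linear8bitLt is expected to cache the transformed int8
    # weight in state.CxB after the first forward pass (checked at i == 1).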
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
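    # Train an int8 stack and an fp16 reference stack side by side with
    # gradient accumulation; gradients must match closely, weights only
    # approximately (they are periodically re-synced).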
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        print(i)
        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)


@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
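    # With has_fp16_weights=False the weights are stored as int8; when
    # threshold > 0, the outlier index state (state.idx) is expected to be
    # populated after a forward pass. The same checks are repeated for
    # different cast orders (.cuda().half(), .half().cuda(), .half().to).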
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .cuda()
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = MLP8bit(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    if memory_efficient_backward:
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        scale = grad_ref.abs().mean()

        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005


@pytest.mark.parametrize(
    "module",
    [
        lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False),
        bnb.nn.LinearFP4,
    ],
    ids=['Int8Lt', 'FP4'],
)
def test_linear_kbit_fp32_bias(module):
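    # The fp32 bias should be cast to fp16 lazily on the first forward pass
    # with fp16 inputs; without a bias, nothing should change.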
    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # casts bias to fp32
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64, bias=False).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert l1.bias is None

modules = []
modules.append(bnb.nn.Linear8bitLt)
modules.append(bnb.nn.Linear4bit)
modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("module", modules, ids=names)
def test_kbit_backprop(module):
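    # Backprop through a k-bit layer must stay close to an fp16 reference
    # layer initialized with the same weights, for both outputs and gradients.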
    b = 17
    dim1 = 37
    dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    errs1 = []
    errs2 = []
    relerrs1 = []
    relerrs2 = []
    for i in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        err1 = (out1-out2).abs().float()
        err2 = (grad1-grad2).abs().float()
        relerr1 = (err1/(out1.abs().float()+1e-9))
        relerr2 = (err2/(grad1.abs().float()+1e-9))
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
        relerrs2.append(relerr2.mean().item())

        if module == bnb.nn.Linear8bitLt:
            torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
    print('out', sum(errs1)/len(errs1))
    print('grad', sum(errs2)/len(errs2))
    print('rel out', sum(relerrs1)/len(relerrs1))
    print('rel grad', sum(relerrs2)/len(relerrs2))

def test_fp8linear():
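    # LinearFP8Mixed should closely track an fp32 nn.Linear in both forward
    # outputs and weight/bias gradients.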

    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h*2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
    fp32b = torch.nn.Linear(h*2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()

    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
    fp8b.weight.data.copy_(fp32b.weight.data)
    fp8b.bias.data.copy_(fp32b.bias.data)

    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
    b = fp8b(torch.nn.functional.gelu(fp8(inp)))

    err = (a-b).abs().mean()

    a.mean().backward()
    b.mean().backward()

    graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002