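"""Tests for bitsandbytes quantized linear modules (int8 Linear8bitLt, 4-bit, and FP8)."""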
import math

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb
from tests.helpers import id_formatter


class MockArgs:
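    # Minimal stand-in for an argparse Namespace: copies key/value pairs onto itself.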
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
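    # Two stacked Linear8bitLt layers; the quantized model-under-test for the int8 tests.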
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
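    # Like torch.testing.assert_close, but tolerate up to `count` mismatched
    # elements before failing; int8 kernels occasionally differ by one
    # quantization step.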
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


class LinearFunction(torch.autograd.Function):
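    # Pure-PyTorch reference for 8-bit training: quantize the operands, multiply
    # with an integer GEMM, and dequantize, in both forward and backward.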
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    @staticmethod
    def quant(x, quant_type, dim=1):
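        # Symmetric absmax quantization to int8: "linear" uses one scale for the
        # whole tensor, "vector" one scale per slice along `dim`, and "min-max"
        # an affine scale plus offset.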
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    @staticmethod
    def dequant(xq, S1, S2, dtype, quant_type):
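        # Invert quant(): the integer matmul output carries both input scales,
        # so divide by 127 once per quantized operand.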
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    @staticmethod
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    @staticmethod
    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
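        # Quantize-dequantize with one absmax scale per slice along `dim`, so an
        # outlier in one row does not crush the resolution of the others.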
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
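        # Stochastic rounding: round up with probability equal to the fractional
        # part, making the rounding error zero-mean in expectation.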
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
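        # Emulate 8-bit weight storage: blockwise-quantize w, dequantize, and
        # write the result back into w in place.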
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
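        # Keep the `topk` largest-magnitude values per 256-element block in full
        # precision and quantize only the remainder (outliers dominate the error).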
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
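        # Quantize activations and weights, run the int8 matmul (igemm), then
        # dequantize back to the input dtype.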
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(
                outputq, S1, S2, x.dtype, args.quant_type
            )
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
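        # Mirror the forward-pass quantization: depending on use_8bit_training,
        # the weight-gradient and input-gradient matmuls also run in int8.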
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=2
            )
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
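    # Linear layer routed through LinearFunction; quantization behavior is
    # controlled by the `args` object (quant_type, use_8bit_training, ...).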
    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold):
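    # Smoke test: after the first forward pass the layer caches its quantized
    # weight in state.CxB for reuse at inference time.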
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
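    # Train Linear8bitLt against an fp16 nn.Linear twin under gradient
    # accumulation; weights are re-synced after each optimizer step so small
    # quantization divergences do not compound.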
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l1[0].weight.data.copy_(l2[0].weight.data)
    l1[1].weight.data.copy_(l2[1].weight.data)
    l1[0].bias.data.copy_(l2[0].bias.data)
    l1[1].bias.data.copy_(l2[1].bias.data)

    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
            l1[0].bias.data.copy_(l2[0].bias.data)
            l1[1].bias.data.copy_(l2[1].bias.data)
        else:
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)


@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
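    # With has_fp16_weights=False the weights must stay int8 regardless of the
    # order of .cuda()/.half()/.to() calls, and outputs must remain fp16.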
    l1 = (
        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
        .cuda()
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .cuda()
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
        .half()
        .to("cuda")
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = MLP8bit(
        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
    )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    if memory_efficient_backward:
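        # Check the input gradient against an analytic fp16 reference built from
        # the weights (w1, w2) saved before quantization.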
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        scale = grad_ref.abs().mean().item()  # .item(): assert_close/isclose expect a Python float atol

        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005


@pytest.mark.parametrize(
    "module",
    [
        lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
        bnb.nn.LinearFP4,
    ],
    ids=['Int8Lt', 'FP4'],
)
def test_linear_kbit_fp32_bias(module):
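    # An fp32 bias should be cast down to the fp16 compute dtype on the first
    # forward pass.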
    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # casts bias to fp32
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64, bias=False).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert l1.bias is None


module_dict = {
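    # Factories keyed by pytest id: "+C" compresses quantization statistics,
    # "+<dtype>" pins the compute dtype.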
    "Int8Lt": bnb.nn.Linear8bitLt,
    "4bit": bnb.nn.Linear4bit,
    "FP4": bnb.nn.LinearFP4,
    "NF4": bnb.nn.LinearNF4,
    "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
    "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
    "NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
    "NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
    "NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
}


@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
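    # Each k-bit layer should track an fp16 reference within loose tolerances
    # for outputs and gradients; also verify zero_grad() actually clears grads.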
    b = 17
    dim1 = 37
    dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    errs1 = []
    errs2 = []
    relerrs1 = []
    relerrs2 = []
    for i in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        err1 = (out1 - out2).abs().float()
        err2 = (grad1 - grad2).abs().float()
        relerr1 = err1 / (out1.abs().float() + 1e-9)
        relerr2 = err2 / (grad1.abs().float() + 1e-9)
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
        relerrs2.append(relerr2.mean().item())

        if isinstance(kbit[1], bnb.nn.Linear8bitLt):  # check the instance, not the factory passed by pytest
            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
    #print('out', sum(errs1)/len(errs1))
    #print('grad', sum(errs2)/len(errs2))
    #print('rel out', sum(relerrs1)/len(relerrs1))
    #print('rel grad', sum(relerrs2)/len(relerrs2))

def test_fp8linear():
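    # LinearFP8Mixed should closely track an fp32 reference MLP; note the much
    # tighter tolerances for gradients than for activations below.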
    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h*2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
    fp32b = torch.nn.Linear(h*2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()

    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
    fp8b.weight.data.copy_(fp32b.weight.data)
    fp8b.bias.data.copy_(fp32b.bias.data)

    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
    b = fp8b(torch.nn.functional.gelu(fp8(inp)))

    err = (a-b).abs().mean()

    a.mean().backward()
    b.mean().backward()

    graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002

def test_4bit_warnings():
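    # Linear4bit with fp32 compute on fp16 inputs should emit exactly one
    # UserWarning per slow inference/training path.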
    dim1 = 64

    with pytest.warns(UserWarning, match=r'inference or training'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)
    with pytest.warns(UserWarning, match=r'inference.'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    with pytest.warns(UserWarning) as record:

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    assert len(record) == 2