import math

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb


class MockArgs:
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
            threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
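    # Allow up to `count` elements to violate the tolerances; beyond that, defer to
    # torch.testing.assert_close for a detailed mismatch report.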
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
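        # Simulated int8 quantization: clip at `trim_value` standard deviations,
        # round to the [-127, 127] grid, then rescale back to the original range.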
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    def quant(x, quant_type, dim=1):
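        # Reference quantizers used by the igemm-based training path:
        #   "linear"  - one absmax scale for the whole tensor
        #   "vector"  - absmax scale per slice along `dim`
        #   "min-max" - asymmetric scaling from per-slice min/max
        # Each branch returns the int8 tensor plus the state needed to dequantize.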
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    def dequant(xq, S1, S2, dtype, quant_type):
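        # Undo `quant` on an int32 igemm result: one factor of scale/127 per
        # int8 operand of the matmul.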
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    def dequant_min_max(xq, A, B, SA, SB, dtype):
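        # Dequantize an igemm result whose A operand used the asymmetric "min-max"
        # scheme: rescale by the scales, then add the offset contributed by A's shift.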
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
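        # Simulated quantization with a single absmax scale for the whole tensor.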
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
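        # Simulated quantization with per-slice absmax scales along `dim`.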
        round_func = (
            LinearFunction.round_stoachastic if stochastic else torch.round
        )
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
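        # Stochastic rounding: round up with probability equal to the fractional
        # part, down otherwise, so the rounding is unbiased in expectation.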
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
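        # Round-trip the weight through blockwise 8-bit quantization with a dynamic
        # code and write the result back, emulating 8-bit weight storage.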
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
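        # Same round-trip, but with a quantization code estimated from the weight's
        # own quantiles instead of the fixed dynamic map.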
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
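        # Same round-trip, exercising the stochastic-rounding path of
        # quantize_blockwise via the `rand` tensor.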
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
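        # Keep the `topk` largest-magnitude values of every 256-element block in
        # full precision and quantize only the remaining values.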
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
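        # Quantize activations and weights, multiply with an int8 igemm, and
        # dequantize back to the input dtype; with 8-bit training disabled, fall
        # back to a regular floating-point einsum.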
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(
                outputq, S1, S2, x.dtype, args.quant_type
            )
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
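        # "forward+wgrad": only the weight gradient uses int8 igemm.
        # "full": both the weight and input gradients use int8 igemm.
        # Anything else falls back to regular matmuls.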
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=2
            )
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
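    # Test-only linear layer that dispatches to LinearFunction above so the
    # simulated 8-bit training path can be exercised end to end.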
    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


threshold = [0.0, 3.0]
values = threshold
names = [f"threshold_{vals}" for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
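    # Repeated fp16 inference; the cached CxB weight layout must exist after the
    # first forward pass.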
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
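    # Int8 layers should track an fp16 reference when gradients are accumulated
    # over several steps before each optimizer update.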
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l1[0].weight.data.copy_(l2[0].weight.data)
    l1[1].weight.data.copy_(l2[1].weight.data)
    l1[0].bias.data.copy_(l2[0].bias.data)
    l1[1].bias.data.copy_(l2[1].bias.data)

    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(15):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
            l1[0].bias.data.copy_(l2[0].bias.data)
            l1[1].bias.data.copy_(l2[1].bias.data)
        else:
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)


@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
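    # With has_fp16_weights=False the weights must end up as int8 regardless of the
    # order of .half()/.cuda()/.to() casts, and outlier indices must be tracked
    # whenever a nonzero threshold is set.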
    l1 = bnb.nn.Linear8bitLt(
        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
    ).cuda().half()
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .half()
        .cuda()
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(
        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
    ).half().to("cuda")

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = MLP8bit(
            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
        )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    if memory_efficient_backward:
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        scale = grad_ref.abs().mean()

        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005


@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
def test_linear_kbit_fp32_bias(module):
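    # The k-bit layers keep an fp32 bias until the first forward pass, which casts
    # it to the dtype of the fp16 input.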
    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # the forward pass casts the fp32 bias to the fp16 input dtype
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64, bias=False).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert l1.bias is None

modules = []
modules.append(bnb.nn.Linear8bitLt)
modules.append(bnb.nn.Linear4bit)
modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'FP4+fp32', 'FP4+fp16', 'FP4+bf16']
@pytest.mark.parametrize("module", modules, ids=names)
def test_kbit_backprop(module):
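    # Compare outputs and gradients of a k-bit linear layer against an fp16
    # reference nn.Linear with identical initial weights.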
    b = 17
    dim1 = 37
    dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    errs1 = []
    errs2 = []
    relerrs1 = []
    relerrs2 = []
    for i in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        err1 = (out1-out2).abs().float()
        err2 = (grad1-grad2).abs().float()
        relerr1 = (err1/(out1.abs().float()+1e-9))
        relerr2 = (err2/(grad1.abs().float()+1e-9))
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
        relerrs2.append(relerr2.mean().item())

        if module == bnb.nn.Linear8bitLt:
            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
    #print('out', sum(errs1)/len(errs1))
    #print('grad', sum(errs2)/len(errs2))
    #print('rel out', sum(relerrs1)/len(relerrs1))
    #print('rel grad', sum(relerrs2)/len(relerrs2))

def test_fp8linear():
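    # Forward error and gradient error of an FP8 mixed-precision MLP relative to an
    # fp32 reference must stay within loose tolerances.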

    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h*2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
    fp32b = torch.nn.Linear(h*2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()

    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
    fp8b.weight.data.copy_(fp32b.weight.data)
    fp8b.bias.data.copy_(fp32b.bias.data)

    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
    b = fp8b(torch.nn.functional.gelu(fp8(inp)))

    err = (a-b).abs().mean()

    a.mean().backward()
    b.mean().backward()

    graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002

def test_4bit_warnings():
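    # Linear4bit with a compute_dtype that does not match the fp16 input should
    # emit exactly one performance UserWarning per configuration (two in total).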
    dim1 = 64

    with pytest.warns(UserWarning, match=r'inference or training'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)
    with pytest.warns(UserWarning, match=r'inference.'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    with pytest.warns(UserWarning) as record:

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    assert len(record) == 2