import math

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb
from tests.helpers import id_formatter


class MockArgs:
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1,
            dim2,
            has_fp16_weights=has_fp16_weights,
            memory_efficient_backward=memory_efficient_backward,
            threshold=threshold,
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2,
            dim1,
            has_fp16_weights=has_fp16_weights,
            memory_efficient_backward=memory_efficient_backward,
            threshold=threshold,
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


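# Like torch.testing.assert_close, but tolerates up to `count` mismatching elements
# before failing, which is useful for quantized results where a handful of borderline
# values are expected to round differently.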
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

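    # Reference int8 quantization schemes used by the tests below:
    #   "linear":  one absmax scale for the whole tensor
    #   "vector":  an absmax scale per slice along `dim`
    #   "min-max": asymmetric scheme with a per-slice offset (minA) and scale
    # Example (vector): a row with absmax 0.5 maps 0.5 -> 127 and -0.25 -> -64.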
    def quant(x, quant_type, dim=1):
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

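    # Inverse of `quant` for a matmul result: each int8 operand was scaled by 127 / max,
    # so the int32 igemm output is rescaled by S1 * S2 / 127**2 (broadcast per slice
    # for the "vector" scheme).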
    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

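    # Stochastic rounding: round the magnitude up with probability equal to its
    # fractional part, so the rounding error is zero in expectation.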
    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

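    # The fake_8bit_storage* helpers simulate 8-bit weight storage: the weight is
    # quantized and immediately dequantized block-wise and the rounded fp16 result
    # is copied back into `w`.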
    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

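    # Variant that keeps the `topk` largest-magnitude values of each 256-element
    # block in full precision and quantizes only the remaining values.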
    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

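    # Forward pass: quantize the activations and the weight, multiply them with the
    # integer kernel (bnb.functional.igemm), then dequantize the int32 result back to
    # the input dtype; otherwise fall back to a plain einsum matmul.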
    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

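    # Backward pass: "forward+wgrad" computes only the weight gradient with an 8-bit
    # matmul, "full" computes both the weight and the input gradient in 8-bit, and
    # anything else falls back to regular matmuls.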
    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


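# Small reference linear layer that routes through the custom LinearFunction above,
# so the quantized matmul path can be exercised via `args.use_8bit_training`.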
class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


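# Checks that Linear8bitLt runs fp16 inference on CUDA and that the transposed int8
# weight buffer (state.CxB) is populated after the first forward pass.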
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


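# Compares an int8 stack against an fp16 reference while accumulating gradients,
# re-syncing the weights after each optimizer step so small divergences do not add up.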
def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l1[0].weight.data.copy_(l2[0].weight.data)
    l1[1].weight.data.copy_(l2[1].weight.data)
    l1[0].bias.data.copy_(l2[0].bias.data)
    l1[1].bias.data.copy_(l2[1].bias.data)

    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
            l1[0].bias.data.copy_(l2[0].bias.data)
            l1[1].bias.data.copy_(l2[1].bias.data)
        else:
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)


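# With has_fp16_weights=False the weights should be stored as int8 regardless of the
# order in which .cuda()/.half()/.to() are applied, outputs should stay fp16, and the
# outlier index (state.idx) should be populated whenever threshold > 0.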
@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
    l1 = (
        bnb.nn.Linear8bitLt(
            32,
            64,
            threshold=threshold,
            has_fp16_weights=False,
            memory_efficient_backward=memory_efficient_backward,
        )
        .cuda()
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = (
        MLP8bit(
            32,
            64,
            threshold=threshold,
            has_fp16_weights=False,
            memory_efficient_backward=memory_efficient_backward,
        )
        .half()
        .to("cuda")
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = MLP8bit(
        32,
        64,
        threshold=threshold,
        has_fp16_weights=False,
        memory_efficient_backward=memory_efficient_backward,
    )
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
    mlp = mlp.cuda().half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    if memory_efficient_backward:
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        scale = grad_ref.abs().mean()

        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005


@pytest.mark.parametrize(
    "module",
    [
        lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
        bnb.nn.LinearFP4,
    ],
    ids=["Int8Lt", "FP4"],
)
def test_linear_kbit_fp32_bias(module):
    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # the first forward casts the fp32 bias to fp16 to match the input dtype
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64, bias=False).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert l1.bias is None


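# Maps test ids to layer factories for test_kbit_backprop. Note that, as written, the
# "NF4+fp32/fp16/bf16" entries construct LinearFP4 with an explicit compute dtype
# rather than LinearNF4.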
module_dict = {
    "Int8Lt": bnb.nn.Linear8bitLt,
    "4bit": bnb.nn.Linear4bit,
    "FP4": bnb.nn.LinearFP4,
    "NF4": bnb.nn.LinearNF4,
    "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
    "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
    "NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
    "NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
    "NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
}


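# Runs a small fp16 reference MLP and a k-bit MLP side by side and checks that outputs
# and gradients agree within loose, module-dependent tolerances.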
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
    b = 17
    dim1 = 37
    dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    errs1 = []
    errs2 = []
    relerrs1 = []
    relerrs2 = []
    for i in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        err1 = (out1 - out2).abs().float()
        err2 = (grad1 - grad2).abs().float()
        relerr1 = err1 / (out1.abs().float() + 1e-9)
        relerr2 = err2 / (grad1.abs().float() + 1e-9)
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
        relerrs2.append(relerr2.mean().item())

        # `module` may be a class or factory, not an instance, so check the built layer
        if isinstance(kbit[1], bnb.nn.Linear8bitLt):
            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
    # print('out', sum(errs1)/len(errs1))
    # print('grad', sum(errs2)/len(errs2))
    # print('rel out', sum(relerrs1)/len(relerrs1))
    # print('rel grad', sum(relerrs2)/len(relerrs2))


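# Compares LinearFP8Mixed against an fp32 reference on a two-layer GELU MLP, for both
# the forward output and the weight/bias gradients.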
def test_fp8linear():
    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h * 2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
    fp32b = torch.nn.Linear(h * 2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()

    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
    fp8b.weight.data.copy_(fp32b.weight.data)
    fp8b.bias.data.copy_(fp32b.bias.data)

    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
    b = fp8b(torch.nn.functional.gelu(fp8(inp)))

    err = (a - b).abs().mean()

    a.mean().backward()
    b.mean().backward()

    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002


def test_4bit_warnings():
    dim1 = 64

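    # Linear4bit with compute_dtype=torch.float32 and fp16 inputs is expected to warn
    # about the slow path; the expected message differs for batched vs. single-row
    # inputs, hence the two match patterns below.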
    with pytest.warns(UserWarning, match=r"inference or training"):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)
    with pytest.warns(UserWarning, match=r"inference."):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    with pytest.warns(UserWarning) as record:
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    assert len(record) == 2