test_modules.py 26.7 KB
Newer Older
1
import inspect
Aarni Koskela's avatar
Aarni Koskela committed
2
import math
3

Aarni Koskela's avatar
Aarni Koskela committed
4
import einops
5
6
import pytest
import torch
Tim Dettmers's avatar
Tim Dettmers committed
7
8
from torch import nn

9
import bitsandbytes as bnb
Aarni Koskela's avatar
Aarni Koskela committed
10
from tests.helpers import id_formatter
11

12

13
class MockArgs:
    """Minimal attribute bag: every key of ``initial_data`` becomes an attribute."""

    def __init__(self, initial_data):
        # Keys are iterated and looked up individually, so any container that
        # supports iteration plus ``__getitem__`` is accepted.
        for name in initial_data:
            value = initial_data[name]
            setattr(self, name, value)

18

Tim Dettmers's avatar
Tim Dettmers committed
19
class MLP8bit(torch.nn.Module):
    """Two stacked ``bnb.nn.Linear8bitLt`` layers mapping ``dim1 -> dim2 -> dim1``.

    ``has_fp16_weights``, ``memory_efficient_backward`` and ``threshold`` are
    forwarded unchanged to both layers.
    """

    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        shared_options = {
            "has_fp16_weights": has_fp16_weights,
            "memory_efficient_backward": memory_efficient_backward,
            "threshold": threshold,
        }
        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, **shared_options)
        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, **shared_options)

    def forward(self, x):
        # No activation between the layers: fc1 feeds straight into fc2.
        return self.fc2(self.fc1(x))


def get_args():
    """Build the default ``MockArgs`` used for simulated 8-bit training.

    Returns:
        A ``MockArgs`` with ``quant_type="vector"``, ``use_8bit_training="full"``
        and ``clip_freq=9999``.
    """
    settings = (
        ("quant_type", "vector"),
        ("use_8bit_training", "full"),
        ("clip_freq", 9999),
    )
    args = MockArgs([])
    for attr_name, attr_value in settings:
        setattr(args, attr_name, attr_value)
    return args

50

Tim Dettmers's avatar
Tim Dettmers committed
51
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    """Allow up to ``count`` elements of ``a`` and ``b`` to differ beyond tolerance.

    Elements are compared with ``torch.isclose(a, b, rtol=rtol, atol=atol)``.
    When more than ``count`` elements mismatch, a diagnostic line is printed and
    ``torch.testing.assert_close`` is called with the same tolerances, so the
    failure carries torch's detailed mismatch report.
    """
    close_mask = torch.isclose(a, b, rtol=rtol, atol=atol)
    n_mismatched = (~close_mask).sum().item()
    if n_mismatched <= count:
        return
    print(f"Too many values not close: assert {n_mismatched} < {count}")
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
Tim Dettmers's avatar
Tim Dettmers committed
57
58


59
class LinearFunction(torch.autograd.Function):
    """Reference autograd function that simulates 8-bit training of a linear layer.

    ``forward``/``backward`` are driven by an ``args`` object exposing
    ``use_8bit_training`` (``"off"`` disables quantization; ``backward`` also
    distinguishes ``"forward+wgrad"`` and ``"full"``) and ``quant_type``
    (``"linear"``, ``"vector"`` or ``"min-max"``, see ``quant``). The other
    static helpers fake-quantize tensors or emulate 8-bit weight storage.
    """

    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        """Fake-quantize ``x`` to int8 with its range clipped at ``trim_value`` standard deviations.

        Values are scaled so that ``std(x) * trim_value`` maps to 127, rounded
        (stochastically when ``stochastic``), clamped to [-127, 127] and scaled back.
        """
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        # Clip values outside the trimmed range to the int8 limits.
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    @staticmethod
    def quant(x, quant_type, dim=1):
        """Quantize ``x`` to int8 and return ``(xq, state)``.

        - ``"linear"``: one tensor-wide absmax scale (``state`` is that max).
        - ``"vector"``: absmax along ``dim`` with ``keepdim=True``.
        - ``"min-max"``: per-``dim`` range; ``state`` is ``(min, (max - min) / 2)``.

        Returns ``None`` for any other ``quant_type``.
        """
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    @staticmethod
    def dequant(xq, S1, S2, dtype, quant_type):
        """Rescale an integer matmul result ``xq`` to ``dtype`` using both operands' scales.

        Handles the ``"linear"`` and ``"vector"`` states produced by ``quant``;
        returns ``None`` for any other ``quant_type``.
        """
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            # Drop a leading singleton dim from 3-D scales when the result is 2-D.
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    @staticmethod
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        """Dequantize an integer product whose left operand used min-max quantization.

        ``SA`` is the ``(min, scale)`` pair from ``quant(..., "min-max")`` and
        ``SB`` the absmax scale of the other operand. The offset
        ``B.t().sum(0) * (SA[0] + SA[1])`` is added back after rescaling.
        """
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    @staticmethod
    def get_8bit_linear(x, stochastic=False):
        """Fake-quantize ``x`` to int8 using a single tensor-wide absmax scale."""
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        """Fake-quantize ``x`` to int8 with an absmax scale per slice along ``dim``.

        All-zero slices get a scale of 1 to avoid division by zero.
        """
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
        """Round ``|x|`` up with probability equal to its fractional part, keeping the sign."""
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        """Round-trip ``w`` through blockwise quantization with a dynamic map.

        The map is built with ``create_dynamic_map(n=exponent_bits)``. The fp16
        result is copied into ``w`` in place and also returned.
        """
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        """Like ``fake_8bit_storage`` but with a code book estimated from ``w``'s quantiles.

        The quantile estimate uses ``offset=args.offset`` and is normalized to
        max-abs 1 before quantizing. The result is copied into ``w`` in place.
        """
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        """Blockwise quantization round-trip of ``w`` with random values passed as ``rand``.

        The fp16 result is copied into ``w`` in place and returned.
        """
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        """Blockwise-quantize ``w`` in blocks of 256 while keeping each block's ``topk`` largest-|x| values exact.

        Assumes ``w.numel()`` is a multiple of 256, as required by the rearrange.
        """
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        # NOTE(review): `code` is created but never passed to quantize_blockwise,
        # so the library's default code book is used — confirm this is intended.
        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        """Compute ``x @ weight.T (+ bias)``, simulating int8 operands unless 8-bit training is ``"off"``.

        ``x`` is presumably ``(batch, seq, in_features)``: the float path uses
        ``einsum("bsi,oi->bso")`` and the 8-bit path quantizes ``x`` along ``dim=2``.
        """
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, args)``; ``args`` always gets ``None``.

        ``"forward+wgrad"`` computes only the weight gradient with an int8 igemm.
        ``"full"`` also does so for the input gradient. Any other mode uses float math.
        """
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        grad_input = grad_weight = grad_bias = None

        if bias is not None and ctx.needs_input_grad[2]:
            # Reduce over every leading (batch/sequence) dimension so the
            # gradient has the bias' 1-D shape for both 2-D and 3-D activations.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None
277

278

Tim Dettmers's avatar
Tim Dettmers committed
279
280
class Linear8bit(nn.Module):
    """Linear layer whose forward/backward go through ``LinearFunction``.

    The weight is Xavier-uniform initialised and the optional bias is zeroed.
    ``args`` is handed to ``LinearFunction`` and receives the module's
    ``training`` flag on every forward call.
    """

    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features, self.output_features = input_features, output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.empty(output_features))

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        shared_args = self.args
        # Mirror the train/eval mode onto the (shared, mutable) args object.
        shared_args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, shared_args)


Aarni Koskela's avatar
Aarni Koskela committed
302
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold):
    """An fp16 Linear8bitLt in eval mode runs repeated forwards and populates ``state.CxB``."""
    layer = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert layer.weight.device.type == "cuda"
    assert layer.weight.dtype == torch.float16

    layer.eval()
    for step in range(100):
        inputs = torch.randn(16, 8, 32, device="cuda").half()
        layer(inputs)
        # After the second forward call the layer state must hold CxB.
        if step == 1:
            assert layer.state.CxB is not None

315

Tim Dettmers's avatar
Tim Dettmers committed
316
def test_linear8bitlt_accumulated_gradient():
    """Linear8bitLt accumulates gradients like torch.nn.Linear across repeated backward passes.

    Gradients are accumulated without zeroing and compared on every
    non-step iteration. Once ``acc_steps`` iterations have run, both models
    take an Adam32bit step, their weights are compared, and the reference
    parameters are copied back into the 8-bit model.
    """
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l1[0].weight.data.copy_(l2[0].weight.data)
    l1[1].weight.data.copy_(l2[1].weight.data)
    l1[0].bias.data.copy_(l2[0].bias.data)
    l1[1].bias.data.copy_(l2[1].bias.data)

    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    # Iterate one step past ``acc_steps``: with ``range(acc_steps)`` the
    # optimizer branch (``i > 0 and i % acc_steps == 0``) could never run.
    for i in range(acc_steps + 1):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
            l1[0].bias.data.copy_(l2[0].bias.data)
            l1[1].bias.data.copy_(l2[1].bias.data)
        else:
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
356
357


358
@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
    """Layers built with ``has_fp16_weights=False`` keep int8 weights and emit fp16 outputs.

    Covers a bare Linear8bitLt and MLP8bit moved to CUDA/fp16 in several
    orders. With a positive ``threshold``, both MLP layers must record outlier
    indices (``state.idx``).
    """

    def check_forwards(model, check_idx):
        # 100 random fp16 batches: output must be fp16; for MLPs with a
        # positive threshold both layers must have populated ``state.idx``.
        for _ in range(100):
            batch = torch.randn(16, 8, 32, device="cuda").half()
            out = model(batch)
            assert out.dtype == torch.float16
            if check_idx and threshold > 0:
                assert model.fc1.state.idx is not None
                assert model.fc2.state.idx is not None

    def check_int8(model):
        assert model.fc1.weight.dtype == torch.int8
        assert model.fc2.weight.dtype == torch.int8

    def check_cuda(model):
        assert model.fc1.weight.device.type == "cuda"
        assert model.fc2.weight.device.type == "cuda"

    layer = (
        bnb.nn.Linear8bitLt(
            32,
            64,
            threshold=threshold,
            has_fp16_weights=False,
            memory_efficient_backward=memory_efficient_backward,
        )
        .cuda()
        .half()
    )
    assert layer.weight.dtype == torch.int8

    layer.eval()
    check_forwards(layer, check_idx=False)

    # .cuda() only
    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    check_int8(mlp)
    check_forwards(mlp, check_idx=True)

    # .cuda() then .half()
    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    check_int8(mlp)
    check_forwards(mlp, check_idx=True)

    # .half() then .cuda()
    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
    check_forwards(mlp, check_idx=True)
    check_int8(mlp)

    # .half() then .to("cuda")
    mlp = (
        MLP8bit(
            32,
            64,
            threshold=threshold,
            has_fp16_weights=False,
            memory_efficient_backward=memory_efficient_backward,
        )
        .half()
        .to("cuda")
    )
    check_forwards(mlp, check_idx=True)
    check_int8(mlp)
    check_cuda(mlp)

    mlp = MLP8bit(
        32,
        64,
        threshold=threshold,
        has_fp16_weights=False,
        memory_efficient_backward=memory_efficient_backward,
    )
    # Snapshot the weights before quantization; moving the model to CUDA triggers it.
    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()
    mlp = mlp.cuda().half()
    check_forwards(mlp, check_idx=True)
    check_int8(mlp)
    check_cuda(mlp)

    if memory_efficient_backward:
        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        assert o1.requires_grad
        grad_proj = torch.randn_like(o1)

        mlp.zero_grad()
        (o1 * grad_proj).sum().backward()
        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
        scale = grad_ref.abs().mean()

        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
        assert (idx == 0).sum().item() <= b1.numel() * 0.005
483

justheuristic's avatar
justheuristic committed
484

485
486
487
488
489
490
@pytest.mark.parametrize(
    "module",
    [
        lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
        bnb.nn.LinearFP4,
    ],
    ids=["Int8Lt", "FP4"],
)
def test_linear_kbit_fp32_bias(module):
    """A quantized linear layer's fp32 bias is cast to fp16 when fed fp16 inputs; no bias stays ``None``."""
    # After ``.cuda()`` the weight must already be quantized (int8 or uint8).
    with_bias = module(32, 64).cuda()
    assert with_bias.weight.dtype in [torch.int8, torch.uint8]
    assert with_bias.bias.dtype == torch.float32

    for _ in range(100):
        batch = torch.randn(16, 8, 32, device="cuda").half()
        # The forward pass must leave the bias in fp16, matching the input.
        with_bias(batch)
        assert with_bias.bias.dtype == torch.float16

    no_bias = module(32, 64, bias=False).cuda()
    assert no_bias.weight.dtype in [torch.int8, torch.uint8]
    assert no_bias.bias is None

    for _ in range(100):
        batch = torch.randn(16, 8, 32, device="cuda").half()
        no_bias(batch)
        assert no_bias.bias is None
514

Aarni Koskela's avatar
Aarni Koskela committed
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529

# Linear-layer factories exercised by ``test_kbit_backprop``; the keys double as test ids.
# Each value is called as ``factory(in_features, out_features)``.
module_dict = {
    "Int8Lt": bnb.nn.Linear8bitLt,
    "4bit": bnb.nn.Linear4bit,
    "FP4": bnb.nn.LinearFP4,
    "NF4": bnb.nn.LinearNF4,
    "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
    "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
    # NF4 layers with an explicit compute dtype (these must build LinearNF4, not LinearFP4).
    "NF4+fp32": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float32),
    "NF4+fp16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float16),
    "NF4+bf16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.bfloat16),
}


@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
    """Gradients flowing back through a quantized layer match a full-precision reference.

    A reference ``Linear -> Linear`` stack and a copy whose second layer is
    built by ``module`` start from identical parameters. For 100 random batches,
    the first layer's weight and bias gradients must agree. Int8 layers use
    tighter tolerances than the 4-bit variants.
    """
    b = 17
    dim1 = 37
    dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    # ``module`` is a class or a factory lambda, never an instance, so inspect
    # the constructed layer to pick the tolerance family.
    is_int8 = isinstance(kbit[1], bnb.nn.Linear8bitLt)

    for _ in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        if is_int8:
            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)

        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
588

589

Ruff's avatar
Ruff committed
590
def test_fp8linear():
    """A two-layer LinearFP8Mixed MLP tracks an fp32 reference in outputs and gradients."""
    batch_size = 10
    hidden = 1024
    inp = torch.randn(batch_size, hidden).cuda()

    fp32 = torch.nn.Linear(hidden, hidden * 2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(hidden, hidden * 2).cuda()
    fp32b = torch.nn.Linear(hidden * 2, hidden).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(hidden * 2, hidden).cuda()

    # Start both stacks from identical parameters.
    for dst, src in ((fp8, fp32), (fp8b, fp32b)):
        dst.weight.data.copy_(src.weight.data)
        dst.bias.data.copy_(src.bias.data)

    gelu = torch.nn.functional.gelu
    ref_out = fp32b(gelu(fp32(inp)))
    fp8_out = fp8b(gelu(fp8(inp)))

    err = (ref_out - fp8_out).abs().mean()

    ref_out.mean().backward()
    fp8_out.mean().backward()

    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002

Ruff's avatar
Ruff committed
619

620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
    "embedding_class,quant_storage",
    [
        (bnb.nn.Embedding8bit, None),
        (bnb.nn.EmbeddingFP4, torch.uint8),
        (bnb.nn.EmbeddingFP4, torch.float32),
        (bnb.nn.EmbeddingNF4, torch.uint8),
        (bnb.nn.EmbeddingNF4, torch.float32),
    ],
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
    """Quantized embeddings of a table filled with -1/+1 must match the fp32 embedding exactly."""
    num_embeddings = 128

    # Random signs only: a {-1, 1} table is expected to compress losslessly.
    positive = torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0
    src_weight = positive.to(torch.float32) * 2 - 1

    emb_base = nn.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        _freeze=True,
        _weight=src_weight,
    )
    # Embedding8bit takes no ``quant_storage`` argument.
    extra_kwargs = {} if embedding_class is bnb.nn.Embedding8bit else {"quant_storage": quant_storage}
    e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, **extra_kwargs)

    e.load_state_dict(emb_base.state_dict())

    emb_base.cuda()
    e.cuda()

    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")

    quantized_out = e(input_tokens)
    reference_out = emb_base(input_tokens)
    torch.testing.assert_close(actual=quantized_out, expected=reference_out)


@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
    "embedding_class,quant_storage",
    [
        (bnb.nn.Embedding8bit, None),
        (bnb.nn.EmbeddingFP4, torch.uint8),
        (bnb.nn.EmbeddingFP4, torch.float32),
        (bnb.nn.EmbeddingNF4, torch.uint8),
        (bnb.nn.EmbeddingNF4, torch.float32),
    ],
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
    """Quantized embeddings reproduce uniform-random fp32 weights within a per-format tolerance.

    8-bit embeddings may deviate by at most 0.05 (absolute), 4-bit ones by 0.20.
    """
    is_8bit = embedding_class is bnb.nn.Embedding8bit

    num_embeddings = 128

    src_weight = torch.rand((num_embeddings, embedding_dim), dtype=torch.float32)

    emb_base = nn.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        _freeze=True,
        _weight=src_weight,
    )
    ctor_kwargs = {"num_embeddings": num_embeddings, "embedding_dim": embedding_dim}
    if not is_8bit:
        ctor_kwargs["quant_storage"] = quant_storage
    e = embedding_class(**ctor_kwargs)

    e.load_state_dict(emb_base.state_dict())

    for module in (emb_base, e):
        module.cuda()

    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")

    tolerance = 0.05 if is_8bit else 0.20
    torch.testing.assert_close(
        actual=e(input_tokens),
        expected=emb_base(input_tokens),
        atol=tolerance,
        rtol=0.0,
    )


def test_4bit_linear_warnings():
    """Linear4bit with an fp32 compute dtype warns when fed fp16 inputs.

    Batched input must warn with "inference or training" and single-row input
    with "inference.". Running one fresh 10-layer network of each kind must
    produce exactly two warnings in total.
    """
    dim1 = 64

    def run_fresh_net(n_rows):
        # Build a new 10-layer network each time and run one fp16 forward pass.
        layers = [bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]
        net = nn.Sequential(*layers).cuda()
        net(torch.rand(n_rows, dim1).cuda().half())

    with pytest.warns(UserWarning, match=r"inference or training"):
        run_fresh_net(10)

    with pytest.warns(UserWarning, match=r"inference."):
        run_fresh_net(1)

    with pytest.warns(UserWarning) as record:
        run_fresh_net(10)
        run_fresh_net(1)

    assert len(record) == 2
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790


def test_4bit_embedding_warnings():
    """An Embedding4bit whose dim is one past ``default_block_size`` emits a UserWarning matching "inference."."""
    num_embeddings = 128
    default_block_size = 64

    with pytest.warns(UserWarning, match=r"inference."):
        emb = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
        emb.cuda()
        tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
        emb(tokens)


def test_4bit_embedding_weight_fsdp_fix():
    """Embedding4bit must recreate ``weight.quant_state`` during forward after it was reset to ``None``."""
    num_embeddings = 64
    embedding_dim = 32

    emb = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    emb.cuda()

    # Simulate the quant state having been dropped from the weight.
    emb.weight.quant_state = None

    tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
    emb(tokens)

    assert emb.weight.quant_state is not None


def test_4bit_linear_weight_fsdp_fix():
    """Linear4bit must recreate ``weight.quant_state`` during forward after it was reset to ``None``."""
    inp_size = 64
    out_size = 32

    layer = bnb.nn.Linear4bit(inp_size, out_size)
    layer.cuda()

    # Simulate the quant state having been dropped from the weight.
    layer.weight.quant_state = None

    layer(torch.randn((1, inp_size), device="cuda"))

    assert layer.weight.quant_state is not None


def test_embedding_not_implemented_error():
    """``state_dict()`` on Embedding4bit and Embedding8bit raises NotImplementedError.

    Each module is constructed outside ``pytest.raises``, so a
    NotImplementedError from a constructor cannot make the test pass.
    """
    for embedding_cls in (bnb.nn.Embedding4bit, bnb.nn.Embedding8bit):
        emb = embedding_cls(32, 32)
        with pytest.raises(NotImplementedError):
            emb.state_dict()