import inspect

import pytest
import torch
from torch import nn

import bitsandbytes as bnb
from tests.helpers import get_available_devices, id_formatter


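# Minimal argparse.Namespace-style container for test configuration.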
class MockArgs:
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


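# Small two-layer MLP built from Linear8bitLt, used to exercise the int8 path end to end.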
class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1,
            dim2,
            has_fp16_weights=has_fp16_weights,
            threshold=threshold,
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2,
            dim1,
            has_fp16_weights=has_fp16_weights,
            threshold=threshold,
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


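# Default quantization settings used by the 8-bit training tests.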
def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


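# Like torch.testing.assert_close, but tolerates up to `count` mismatched elements before failing.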
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


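# With has_fp16_weights=False the layer stores int8 weights; inference should still produce fp16 outputs.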
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(device, threshold):
    if device == "cpu":
        pytest.xfail("Not yet implemented on CPU")

    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
    assert l1.weight.device.type == device
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device).half()
        o1 = l1(b1)
        if i == 1:
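            # state.CB should hold the cached quantized weight after the first forward pass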
            assert l1.state.CB is not None


# TODO: Remove support for training int8 weights
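# Trains an int8 stack against an fp16 reference under gradient accumulation and checks
# that the accumulated gradients and the updated weights stay close.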
@pytest.mark.parametrize("device", get_available_devices())
def test_linear8bitlt_accumulated_gradient(device):
    if device != "cuda":
        pytest.skip("Only supported on CUDA")

    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).to(device).half() for i in range(2)])
    l1[0].weight.data.copy_(l2[0].weight.data)
    l1[1].weight.data.copy_(l2[1].weight.data)
    l1[0].bias.data.copy_(l2[0].bias.data)
    l1[1].bias.data.copy_(l2[1].bias.data)

    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(15):
        b1 = torch.randn(16, 8, 32, device=device).half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CB is not None
            assert l1[1].state.CB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
            l1[0].bias.data.copy_(l2[0].bias.data)
            l1[1].bias.data.copy_(l2[1].bias.data)
        else:
            assert_all_approx_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04, count=1)
            assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)


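# Int8 inference without fp16 master weights: outputs must stay fp16 and, for a nonzero
# threshold, the outlier indices must be tracked in state.idx.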
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 2.0])
def test_linear8bitlt_no_fp16_weights(device, threshold):
    if device == "cpu":
        pytest.xfail("Not yet supported on CPU")

    l1 = (
        bnb.nn.Linear8bitLt(
            32,
            64,
            threshold=threshold,
            has_fp16_weights=False,
        )
        .to(device)
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        o1 = l1(b1)
        assert o1.dtype == torch.float16

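    # repeat with an MLP, permuting the order of .half() and .to(device): quantization must trigger in every case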
    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device)
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = (
        MLP8bit(
            32,
            64,
            threshold=threshold,
            has_fp16_weights=False,
        )
        .half()
        .to(device)
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == device
    assert mlp.fc2.weight.device.type == device

    mlp = MLP8bit(
        32,
        64,
        threshold=threshold,
        has_fp16_weights=False,
    )
    w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device)  # grab weights before quantization,
    mlp = mlp.to(device).half()  # and this line triggers quantization

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == device
    assert mlp.fc2.weight.device.type == device

    b1 = torch.randn(16, 8, 32, device=device, requires_grad=True, dtype=torch.half)
    o1 = mlp(b1)
    assert o1.dtype == torch.float16
    assert o1.requires_grad
    grad_proj = torch.randn_like(o1)

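    # analytic reference: with purely linear layers, dL/dx = grad_proj @ W2 @ W1
    # (the biases do not affect the input gradient)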
    mlp.zero_grad()
    (o1 * grad_proj).sum().backward()
    grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
    scale = grad_ref.abs().mean()

    torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
    idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
    assert (idx == 0).sum().item() <= b1.numel() * 0.005


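# k-bit layers keep an fp32 bias after construction; the first fp16 forward pass casts it down to fp16.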
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
    "module",
    [
        lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
        bnb.nn.LinearNF4,
    ],
    ids=["Int8Lt", "NF4"],
)
def test_linear_kbit_fp32_bias(device, module):
    if device == "cpu":
        pytest.xfail("Not yet implemented on CPU")

    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64).to(device)
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        # the fp16 forward pass casts the bias down to fp16
        o1 = l1(b1)
        assert l1.bias.dtype == torch.float16

    # casts model to fp16 -> int8 automatically
    l1 = module(32, 64, bias=False).to(device)
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias is None

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
        o1 = l1(b1)
        assert l1.bias is None


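# Linear variants under test: int8 and 4-bit (FP4/NF4), with optional compressed
# statistics and explicit compute dtypes.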
module_dict = {
    "Int8Lt": bnb.nn.Linear8bitLt,
    "4bit": bnb.nn.Linear4bit,
    "FP4": bnb.nn.LinearFP4,
    "NF4": bnb.nn.LinearNF4,
    "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
    "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
    "NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
    "NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
    "NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
}


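# Forward outputs and weight/bias gradients of a k-bit layer should match an fp16 reference
# within quantization error.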
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(device, module):
    if device == "cpu":
        pytest.xfail("Not yet implemented on CPU")

    b = 16
    dim1 = 36
    dim2 = 84
    # dim1 = 37
    # dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)])
    # ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().to(device)
    kbit = kbit.half().to(device)

    errs1 = []
    errs2 = []
    relerrs1 = []
    relerrs2 = []
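    # per-step error statistics, collected for debugging only; the assertions below enforce the actual tolerances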
    for i in range(100):
        batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        err1 = (out1 - out2).abs().float()
        err2 = (grad1 - grad2).abs().float()
        relerr1 = err1 / (out1.abs().float() + 1e-9)
        relerr2 = err2 / (grad1.abs().float() + 1e-9)
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
        relerrs2.append(relerr2.mean().item())

        if isinstance(kbit[1], bnb.nn.Linear8bitLt):
            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0


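# Deprecated: the research LinearFP8Mixed layer should closely track an fp32 reference
# in both forward and backward.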
@pytest.mark.deprecated
def test_fp8linear():
    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h * 2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
    fp32b = torch.nn.Linear(h * 2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()

    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
    fp8b.weight.data.copy_(fp32b.weight.data)
    fp8b.bias.data.copy_(fp32b.bias.data)

    out_fp32 = fp32b(torch.nn.functional.gelu(fp32(inp)))
    out_fp8 = fp8b(torch.nn.functional.gelu(fp8(inp)))

    err = (out_fp32 - out_fp8).abs().mean()

    out_fp32.mean().backward()
    out_fp8.mean().backward()

    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
    "embedding_class,quant_storage",
    [
        (bnb.nn.Embedding8bit, None),
        (bnb.nn.EmbeddingFP4, torch.uint8),
        (bnb.nn.EmbeddingFP4, torch.float32),
        (bnb.nn.EmbeddingNF4, torch.uint8),
        (bnb.nn.EmbeddingNF4, torch.float32),
    ],
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
    if device == "cpu":
        pytest.xfail("Not yet supported on CPU")

    num_embeddings = 128

    src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
        torch.float32
    ) * 2 - 1  # Embeddings filled with {-1, 1} values. It should compress losslessly

    emb_base = nn.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        _freeze=True,
        _weight=src_weight,
    )
    if embedding_class is bnb.nn.Embedding8bit:
        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    else:
        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage)

    e.load_state_dict(emb_base.state_dict())

    emb_base.to(device)
    e.to(device)

    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)

    torch.testing.assert_close(
        actual=e(input_tokens),
        expected=emb_base(input_tokens),
    )


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
    "embedding_class,quant_storage",
    [
        (bnb.nn.Embedding8bit, None),
        (bnb.nn.EmbeddingFP4, torch.uint8),
        (bnb.nn.EmbeddingFP4, torch.float32),
        (bnb.nn.EmbeddingNF4, torch.uint8),
        (bnb.nn.EmbeddingNF4, torch.float32),
    ],
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
    if device == "cpu":
        pytest.xfail("Not yet supported on CPU")

    is_8bit = embedding_class is bnb.nn.Embedding8bit

    num_embeddings = 128

    src_weight = torch.rand((num_embeddings, embedding_dim), dtype=torch.float32)

    emb_base = nn.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        _freeze=True,
        _weight=src_weight,
    )
    if is_8bit:
        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    else:
        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage)

    e.load_state_dict(emb_base.state_dict())

    emb_base.to(device)
    e.to(device)

    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)

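    # random weights are not exactly representable, so allow a looser tolerance for 4-bit than for 8-bit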
    torch.testing.assert_close(
        actual=e(input_tokens),
        expected=emb_base(input_tokens),
        atol=0.05 if is_8bit else 0.20,
        rtol=0.0,
    )


@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_linear_warnings(device):
    if device == "cpu":
        pytest.xfail("Not yet implemented on CPU")

    dim1 = 64

    with pytest.warns(UserWarning, match=r"inference or training"):
        net = nn.Sequential(
            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
        )
        net = net.to(device)
        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
        net(inp)
    with pytest.warns(UserWarning, match=r"inference."):
        net = nn.Sequential(
            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
        )
        net = net.to(device)
        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
        net(inp)

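    # run both input shapes again and check that exactly one warning fires per case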
    with pytest.warns(UserWarning) as record:
        net = nn.Sequential(
            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
        )
        net = net.to(device)
        inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
        net(inp)

        net = nn.Sequential(
            *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
        )
        net = net.to(device)
        inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
        net(inp)

    assert len(record) == 2


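# An embedding dim that is not a multiple of the default 4-bit blocksize (64) should
# trigger the slow-inference warning.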
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_embedding_warnings(device):
    if device == "cpu":
        pytest.xfail("Not yet implemented on CPU")

    num_embeddings = 128
    default_block_size = 64

    with pytest.warns(UserWarning, match=r"inference."):
        net = bnb.nn.Embedding4bit(
            num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
        )
        net.to(device)
        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
        net(inp)


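# FSDP can leave a 4-bit parameter without its quant_state; the forward pass is expected to restore it.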
def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
    num_embeddings = 64
    embedding_dim = 32

    module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

    module.cuda()

    module.weight.quant_state = None

    input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")

    module(input_tokens)

    assert module.weight.quant_state is not None


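# Same quant_state restoration check as above, but for Linear4bit.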
def test_4bit_linear_weight_fsdp_fix(requires_cuda):
    inp_size = 64
    out_size = 32

    module = bnb.nn.Linear4bit(inp_size, out_size)

    module.cuda()

    module.weight.quant_state = None

    input_tensor = torch.randn((1, inp_size), device="cuda")

    module(input_tensor)

    assert module.weight.quant_state is not None


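# Serializing quantized embeddings via state_dict() is not implemented yet.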
def test_embedding_not_implemented_error():
    with pytest.raises(NotImplementedError):
        emb = bnb.nn.Embedding4bit(32, 32)
        emb.state_dict()

    with pytest.raises(NotImplementedError):
        emb = bnb.nn.Embedding8bit(32, 32)
        emb.state_dict()