test_modules.py 26.1 KB
Newer Older
1
import inspect
Aarni Koskela's avatar
Aarni Koskela committed
2
import math
3

Aarni Koskela's avatar
Aarni Koskela committed
4
import einops
5
6
import pytest
import torch
Tim Dettmers's avatar
Tim Dettmers committed
7
8
from torch import nn

9
import bitsandbytes as bnb
Aarni Koskela's avatar
Aarni Koskela committed
10
from tests.helpers import id_formatter
11

12

13
class MockArgs:
    """Minimal attribute bag used to fake a parsed-arguments object.

    Every key of ``initial_data`` becomes an attribute whose value is
    ``initial_data[key]``.
    """

    def __init__(self, initial_data):
        for name in initial_data:
            value = initial_data[name]
            setattr(self, name, value)

18

Tim Dettmers's avatar
Tim Dettmers committed
19
class MLP8bit(torch.nn.Module):
    """Two stacked ``bnb.nn.Linear8bitLt`` layers mapping ``dim1 -> dim2 -> dim1``.

    ``has_fp16_weights`` and ``threshold`` are forwarded unchanged to both layers.
    """

    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super().__init__()
        layer_kwargs = {"has_fp16_weights": has_fp16_weights, "threshold": threshold}
        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, **layer_kwargs)
        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, **layer_kwargs)

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)


def get_args():
    """Build the default ``MockArgs`` configuration consumed by ``LinearFunction``.

    Sets ``quant_type="vector"``, ``use_8bit_training="full"`` and ``clip_freq=9999``.
    """
    args = MockArgs([])
    defaults = (
        ("quant_type", "vector"),
        ("use_8bit_training", "full"),
        ("clip_freq", 9999),
    )
    for name, value in defaults:
        setattr(args, name, value)
    return args

48

Tim Dettmers's avatar
Tim Dettmers committed
49
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    """Assert ``a`` and ``b`` are elementwise close except for at most ``count`` elements.

    Closeness is decided by ``torch.isclose(a, b, rtol=rtol, atol=atol)``. When more
    than ``count`` elements are not close, a short diagnostic is printed and
    ``torch.testing.assert_close`` is called with the same tolerances, so the raised
    ``AssertionError`` carries its detailed mismatch report.

    Args:
        a: Actual tensor.
        b: Expected tensor.
        atol: Absolute tolerance.
        rtol: Relative tolerance.
        count: Maximum number of mismatching elements that is still accepted.
    """
    num_mismatched = (~torch.isclose(a, b, rtol=rtol, atol=atol)).sum().item()
    if num_mismatched > count:
        # The condition that failed is ``num_mismatched <= count``; report it as such.
        print(f"Too many values not close: assert {num_mismatched} <= {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
Tim Dettmers's avatar
Tim Dettmers committed
55
56


57
class LinearFunction(torch.autograd.Function):
    """Reference autograd ``Function`` that simulates an 8-bit linear layer.

    ``forward``/``backward`` behave according to ``args.use_8bit_training``:

    * ``"off"``: plain floating-point matmuls.
    * ``"forward+wgrad"``: the forward matmul and the weight gradient use int8
      ``bnb.functional.igemm``; the input gradient stays in floating point.
    * ``"full"``: forward, weight gradient and input gradient all use int8 igemm.

    ``forward`` takes the int8 path for every value other than ``"off"``;
    ``backward`` falls back to floating point for values other than the two 8-bit
    modes above. ``args.quant_type`` selects the scheme of :meth:`quant` /
    :meth:`dequant`.

    The other static helpers (``get_8bit_*``, ``dequant_min_max``,
    ``round_stoachastic`` and ``fake_8bit_storage*``) are quantize -> dequantize
    round trips that ``forward``/``backward`` do not call.
    """

    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        """Fake-quantize ``x`` to 8 bits with its range clipped at ``trim_value`` std devs.

        ``trim_value * std(x)`` is mapped to 127; values are rounded (stochastically
        when ``stochastic``), clamped to [-127, 127] and scaled back.
        """
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.std(x) * trim_value
        x = round_func(x / max1 * 127)
        x[x > 127] = 127
        x[x < -127] = -127
        return x / 127 * max1

    @staticmethod
    def quant(x, quant_type, dim=1):
        """Quantize ``x`` to int8.

        Returns ``(xq, scale)``:

        * ``"linear"``: ``scale`` is the absmax of the whole tensor.
        * ``"vector"``: ``scale`` is the absmax along ``dim`` (kept as a dimension).
        * ``"min-max"``: ``scale`` is the tuple ``(min, (max - min) / 2)`` along ``dim``.

        Returns ``None`` for any other ``quant_type``.
        """
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    @staticmethod
    def dequant(xq, S1, S2, dtype, quant_type):
        """Rescale an integer matmul result ``xq`` by the scales of both operands.

        Returns the result as ``dtype``, or ``None`` for a ``quant_type`` other than
        ``"linear"`` / ``"vector"``.
        """
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            # For a 2-D result, drop a leading singleton dim from 3-D scales (no-op otherwise).
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    @staticmethod
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        """Dequantize a matmul result whose left operand used ``"min-max"`` quantization.

        ``SA`` is the ``(min, scale)`` pair from :meth:`quant`; the ``min + scale``
        offset is re-added as ``B.t().sum(0) * (SA[0] + SA[1])``.
        """
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    @staticmethod
    def get_8bit_linear(x, stochastic=False):
        """Fake-quantize ``x`` with a single absmax scale over the whole tensor."""
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        return round_func(x) / 127 * max1

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        """Fake-quantize ``x`` with one absmax scale per slice along ``dim``.

        All-zero slices get a scale of 1 to avoid dividing by zero.
        """
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        return round_func(x) / 127 * max1

    @staticmethod
    def round_stoachastic(x):
        """Round ``|x|`` up with probability equal to its fractional part, keeping the sign.

        (The misspelled name is kept because the helpers above reference it.)
        """
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        """Blockwise quantize/dequantize ``w`` with a dynamic map and write the result into ``w``."""
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        """Like :meth:`fake_8bit_storage`, but with a code estimated from the quantiles of ``w``."""
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        """Blockwise quantize/dequantize ``w`` passing random values as ``rand``; result is written into ``w``."""
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        """Blockwise fake-quantize ``w`` while keeping the ``topk`` largest-magnitude values per block exact.

        ``w`` is viewed as blocks of 256 elements (its numel must be divisible by 256).
        """
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize (with quantize_blockwise's default code)
        # 3. write back max values
        # 4. copy matrix back to weight
        values = blocked_w[mask]
        blocked_w[mask] = 0

        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        """Compute ``x @ weight.T (+ bias)``, via int8 igemm unless ``args.use_8bit_training == "off"``.

        The floating-point path uses ``einsum("bsi,oi->bso")``, so ``x`` is 3-D there.
        The unquantized ``x``, ``weight`` and ``bias`` are saved for ``backward``.
        """
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
        else:
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, args)``; ``args`` gets ``None``."""
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        grad_input = grad_weight = grad_bias = None

        if bias is not None and ctx.needs_input_grad[2]:
            # The bias is broadcast over every leading dim of the output, so its
            # gradient sums over all of them (not only dim 0) to get shape (out_features,).
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        # ``x`` and ``weight`` were saved unquantized; the 8-bit modes re-quantize
        # them (and ``grad_output``) here for the int8 matmuls.
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None
275

276

Tim Dettmers's avatar
Tim Dettmers committed
277
278
class Linear8bit(nn.Module):
    """``nn.Linear``-like module whose forward/backward go through ``LinearFunction``.

    ``args`` is handed to ``LinearFunction`` unchanged. Before every forward pass the
    module's ``training`` flag is written to ``args.training``.
    """

    def __init__(self, input_features, output_features, bias=True, args=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.empty(output_features))

        # Xavier-uniform weights, zero bias.
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        config = self.args
        config.training = self.training
        return LinearFunction.apply(x, self.weight, self.bias, config)


Aarni Koskela's avatar
Aarni Koskela committed
300
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold):
    """An fp16 Linear8bitLt on CUDA runs in eval mode and has ``state.CB`` set after two passes."""
    layer = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert layer.weight.device.type == "cuda"
    assert layer.weight.dtype == torch.float16

    layer.eval()
    for step in range(100):
        batch = torch.randn(16, 8, 32, device="cuda").half()
        layer(batch)
        if step == 1:
            assert layer.state.CB is not None
Tim Dettmers's avatar
Tim Dettmers committed
312

313

Tim Dettmers's avatar
Tim Dettmers committed
314
def test_linear8bitlt_accumulated_gradient():
    """Linear8bitLt gradients and Adam updates track an fp16 ``torch.nn.Linear`` reference.

    Between optimizer steps the accumulated weight gradients are compared. After each
    step (every ``acc_steps`` iterations) the weights are compared and then re-copied
    from the reference.
    """
    model_8bit = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for _ in range(2)])
    model_ref = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for _ in range(2)])

    def copy_params_from_reference():
        for layer_8bit, layer_ref in zip(model_8bit, model_ref):
            layer_8bit.weight.data.copy_(layer_ref.weight.data)
        for layer_8bit, layer_ref in zip(model_8bit, model_ref):
            layer_8bit.bias.data.copy_(layer_ref.bias.data)

    copy_params_from_reference()

    opt_8bit = bnb.optim.Adam32bit(model_8bit.parameters(), lr=0.001)
    opt_ref = bnb.optim.Adam32bit(model_ref.parameters(), lr=0.001)

    acc_steps = 10

    for step in range(15):
        batch = torch.randn(16, 8, 32, device="cuda").half()
        loss_8bit = model_8bit(batch).mean()
        loss_ref = model_ref(batch).mean()
        loss_8bit.backward()
        loss_ref.backward()

        if step == 2:
            assert model_8bit[0].state.CB is not None
            assert model_8bit[1].state.CB is not None

        if step > 0 and step % acc_steps == 0:
            opt_8bit.step()
            opt_8bit.zero_grad(True)
            opt_ref.step()
            opt_ref.zero_grad(True)
            for layer_8bit, layer_ref in zip(model_8bit, model_ref):
                assert_all_approx_close(layer_8bit.weight, layer_ref.weight, rtol=1.05, atol=0.01, count=2)
            # Re-sync so small divergences do not add up over time.
            copy_params_from_reference()
        else:
            for layer_8bit, layer_ref in zip(model_8bit, model_ref):
                assert_all_approx_close(layer_8bit.weight.grad, layer_ref.weight.grad, rtol=1.05, atol=0.04, count=1)
354
355


356
@pytest.mark.parametrize("threshold", [0.0, 2.0])
def test_linear8bitlt_no_fp16_weights(threshold):
    """Linear8bitLt with ``has_fp16_weights=False`` keeps int8 weights and fp16 outputs.

    The check is repeated for several orderings of ``.cuda()`` / ``.half()`` /
    ``.to("cuda")``. The test ends with an input-gradient check against a matmul
    with the pre-quantization weights.
    """

    def make_mlp():
        return MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)

    def assert_int8_weights(model):
        assert model.fc1.weight.dtype == torch.int8
        assert model.fc2.weight.dtype == torch.int8

    def assert_cuda_weights(model):
        assert model.fc1.weight.device.type == "cuda"
        assert model.fc2.weight.device.type == "cuda"

    def run_forward_checks(model):
        # fp16 in -> fp16 out; with threshold > 0 both layers must have state.idx set.
        for _ in range(100):
            out = model(torch.randn(16, 8, 32, device="cuda").half())
            assert out.dtype == torch.float16
            if threshold > 0:
                assert model.fc1.state.idx is not None
            if threshold > 0:
                assert model.fc2.state.idx is not None

    layer = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert layer.weight.dtype == torch.int8

    layer.eval()
    for _ in range(100):
        out = layer(torch.randn(16, 8, 32, device="cuda").half())
        assert out.dtype == torch.float16

    mlp = make_mlp().cuda()
    assert_int8_weights(mlp)
    run_forward_checks(mlp)

    mlp = make_mlp().cuda().half()
    assert_int8_weights(mlp)
    run_forward_checks(mlp)

    mlp = make_mlp().half().cuda()
    run_forward_checks(mlp)
    assert_int8_weights(mlp)

    mlp = make_mlp().half().to("cuda")
    run_forward_checks(mlp)
    assert_int8_weights(mlp)
    assert_cuda_weights(mlp)

    mlp = make_mlp()
    # Grab the weights before quantization; the .cuda() below triggers it.
    w1 = mlp.fc1.weight.clone().cuda()
    w2 = mlp.fc2.weight.clone().cuda()
    mlp = mlp.cuda().half()
    run_forward_checks(mlp)
    assert_int8_weights(mlp)
    assert_cuda_weights(mlp)

    x = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
    out = mlp(x)
    assert out.dtype == torch.float16
    assert out.requires_grad
    grad_proj = torch.randn_like(out)

    mlp.zero_grad()
    (out * grad_proj).sum().backward()
    grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
    scale = grad_ref.abs().mean()

    torch.testing.assert_close(x.grad, grad_ref, rtol=0, atol=0.05 * scale)
    close = torch.isclose(x.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
    assert (~close).sum().item() <= x.numel() * 0.005
476

justheuristic's avatar
justheuristic committed
477

478
479
480
481
482
483
@pytest.mark.parametrize(
    "module",
    [
        lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
        bnb.nn.LinearFP4,
    ],
    ids=["Int8Lt", "FP4"],
)
def test_linear_kbit_fp32_bias(module):
    """A k-bit linear layer keeps an fp32 bias until an fp16 forward pass casts it to fp16.

    After ``.cuda()`` the weight is stored as int8/uint8 and the bias is still
    float32; after a forward pass with fp16 input the bias is float16. When built
    without a bias, ``bias`` stays ``None`` throughout.
    """
    # Moving to CUDA converts the weight to int8/uint8 storage; the bias stays fp32.
    l1 = module(32, 64).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias.dtype == torch.float32

    for _ in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        # The forward pass with fp16 input casts the fp32 bias to fp16.
        l1(b1)
        assert l1.bias.dtype == torch.float16

    # Same conversion without a bias: no bias must be created on the fly.
    l1 = module(32, 64, bias=False).cuda()
    assert l1.weight.dtype in [torch.int8, torch.uint8]
    assert l1.bias is None

    for _ in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        l1(b1)
        assert l1.bias is None
507

Aarni Koskela's avatar
Aarni Koskela committed
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522

# Linear-layer factories exercised by test_kbit_backprop, keyed by test id.
# Each value is called as ``factory(in_features, out_features)``.
module_dict = {
    "Int8Lt": bnb.nn.Linear8bitLt,
    "4bit": bnb.nn.Linear4bit,
    "FP4": bnb.nn.LinearFP4,
    "NF4": bnb.nn.LinearNF4,
    "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
    "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
    # The "NF4+<dtype>" ids must build NF4 layers (they previously built LinearFP4).
    "NF4+fp32": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float32),
    "NF4+fp16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.float16),
    "NF4+bf16": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compute_dtype=torch.bfloat16),
}


@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
    """Gradients flowing back through a k-bit layer stay close to an fp16 reference.

    ``ref`` is ``Linear(dim1, dim2) -> Linear(dim2, 128)``. ``kbit`` has the same first
    layer followed by ``module(dim2, 128)``, initialised from ``ref``'s parameters. The
    weight and bias gradients of the first (plain) layer are compared, since they are
    backpropagated through the second layer.
    """
    b = 16
    dim1 = 36
    dim2 = 84

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)])
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    # ``module`` is a factory (a class or a lambda), never an instance, so the
    # int8 check must inspect the constructed layer.
    if isinstance(kbit[1], bnb.nn.Linear8bitLt):
        grad_atol, bias_atol = 0.008, 0.008
    else:
        grad_atol, bias_atol = 0.015, 0.02

    for _ in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        assert_all_approx_close(grad1, grad2, atol=grad_atol, rtol=0.05, count=1)
        torch.testing.assert_close(bgrad1, bgrad2, atol=bias_atol, rtol=0.05)

        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
579

580

Ruff's avatar
Ruff committed
581
def test_fp8linear():
    """A two-layer GELU MLP of LinearFP8Mixed tracks an fp32 ``nn.Linear`` reference.

    Compares the mean absolute output error and the first layer's weight and bias
    gradient errors.
    """
    batch_size = 10
    hidden = 1024
    inp = torch.randn(batch_size, hidden).cuda()
    fp32 = torch.nn.Linear(hidden, hidden * 2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(hidden, hidden * 2).cuda()
    fp32b = torch.nn.Linear(hidden * 2, hidden).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(hidden * 2, hidden).cuda()

    for target, source in ((fp8, fp32), (fp8b, fp32b)):
        target.weight.data.copy_(source.weight.data)
        target.bias.data.copy_(source.bias.data)

    gelu = torch.nn.functional.gelu
    out_ref = fp32b(gelu(fp32(inp)))
    out_fp8 = fp8b(gelu(fp8(inp)))

    err = (out_ref - out_fp8).abs().mean()

    out_ref.mean().backward()
    out_fp8.mean().backward()

    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002

Ruff's avatar
Ruff committed
610

611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
    "embedding_class,quant_storage",
    [
        (bnb.nn.Embedding8bit, None),
        (bnb.nn.EmbeddingFP4, torch.uint8),
        (bnb.nn.EmbeddingFP4, torch.float32),
        (bnb.nn.EmbeddingNF4, torch.uint8),
        (bnb.nn.EmbeddingNF4, torch.float32),
    ],
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
    """Quantized embeddings reproduce a table whose entries are all -1 or +1 exactly."""
    num_embeddings = 128

    # A {-1, +1}-valued table is expected to survive quantization without loss.
    positive = torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0
    src_weight = positive.to(torch.float32) * 2 - 1

    emb_base = nn.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        _freeze=True,
        _weight=src_weight,
    )
    # Only the 4-bit classes take ``quant_storage``.
    extra_kwargs = {} if embedding_class is bnb.nn.Embedding8bit else {"quant_storage": quant_storage}
    e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, **extra_kwargs)

    e.load_state_dict(emb_base.state_dict())

    emb_base.cuda()
    e.cuda()

    tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")

    torch.testing.assert_close(
        actual=e(tokens),
        expected=emb_base(tokens),
    )


@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
    "embedding_class,quant_storage",
    [
        (bnb.nn.Embedding8bit, None),
        (bnb.nn.EmbeddingFP4, torch.uint8),
        (bnb.nn.EmbeddingFP4, torch.float32),
        (bnb.nn.EmbeddingNF4, torch.uint8),
        (bnb.nn.EmbeddingNF4, torch.float32),
    ],
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
    """Quantized embeddings of a uniform [0, 1) table stay within an absolute tolerance.

    The tolerance is 0.05 for the 8-bit class and 0.20 for the 4-bit classes.
    """
    is_8bit = embedding_class is bnb.nn.Embedding8bit
    num_embeddings = 128

    src_weight = torch.rand((num_embeddings, embedding_dim), dtype=torch.float32)

    emb_base = nn.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        _freeze=True,
        _weight=src_weight,
    )
    extra_kwargs = {} if is_8bit else {"quant_storage": quant_storage}
    e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, **extra_kwargs)

    e.load_state_dict(emb_base.state_dict())

    emb_base.cuda()
    e.cuda()

    tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")

    tolerance = 0.05 if is_8bit else 0.20
    torch.testing.assert_close(
        actual=e(tokens),
        expected=emb_base(tokens),
        atol=tolerance,
        rtol=0.0,
    )


def test_4bit_linear_warnings():
    """Linear4bit stacks with an fp32 compute dtype warn when fed fp16 input.

    A batch of 10 warns with "inference or training" and a batch of 1 with
    "inference."; running both back to back records exactly two warnings.
    """
    dim1 = 64

    def run_fp16_batch(batch_size):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for _ in range(10)])
        net = net.cuda()
        inp = torch.rand(batch_size, dim1).cuda().half()
        net(inp)

    with pytest.warns(UserWarning, match=r"inference or training"):
        run_fp16_batch(10)

    with pytest.warns(UserWarning, match=r"inference."):
        run_fp16_batch(1)

    with pytest.warns(UserWarning) as record:
        run_fp16_batch(10)
        run_fp16_batch(1)

    assert len(record) == 2
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781


def test_4bit_embedding_warnings():
    """An Embedding4bit with ``embedding_dim = 64 + 1`` emits a UserWarning matching "inference." on forward."""
    num_embeddings = 128
    default_block_size = 64
    odd_dim = default_block_size + 1

    with pytest.warns(UserWarning, match=r"inference."):
        net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=odd_dim)
        net.cuda()
        tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
        net(tokens)


def test_4bit_embedding_weight_fsdp_fix():
    """Embedding4bit's forward repopulates ``weight.quant_state`` after it was cleared."""
    num_embeddings = 64
    embedding_dim = 32

    module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    module.cuda()

    # Clear the state, then check that a forward pass sets it again.
    module.weight.quant_state = None
    module(torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda"))

    assert module.weight.quant_state is not None


def test_4bit_linear_weight_fsdp_fix():
    """Linear4bit's forward repopulates ``weight.quant_state`` after it was cleared."""
    inp_size = 64
    out_size = 32

    module = bnb.nn.Linear4bit(inp_size, out_size)
    module.cuda()

    # Clear the state, then check that a forward pass sets it again.
    module.weight.quant_state = None
    module(torch.randn((1, inp_size), device="cuda"))

    assert module.weight.quant_state is not None


def test_embedding_not_implemented_error():
    """``state_dict()`` on Embedding4bit and Embedding8bit raises NotImplementedError."""
    for embedding_class in (bnb.nn.Embedding4bit, bnb.nn.Embedding8bit):
        # Construct outside ``pytest.raises`` so only ``state_dict()`` can satisfy it;
        # a NotImplementedError from the constructor would otherwise pass spuriously.
        emb = embedding_class(32, 32)
        with pytest.raises(NotImplementedError):
            emb.state_dict()