# test_autograd.py
from itertools import permutations, product

import pytest
import torch

import bitsandbytes as bnb

n = 1
k = 25
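# n random sizes are drawn per dimension below; each parametrized case repeats its checks k times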
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
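# cartesian product of all settings, plus matching string tuples used to build readable pytest ids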
values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
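    # Compare the CUBLAS-backed bnb ops against the torch reference for plain, batched,
    # and broadcast matmul shapes, checking forward outputs and (optionally) gradients.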
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)
            elif not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            elif transpose[0] and not transpose[1]:
                out_torch = funcs[0](A.t(), B)
                out_bnb = funcs[1](A.t(), B)
            elif transpose[0] and transpose[1]:
                out_torch = funcs[0](A.t(), B.t())
                out_bnb = funcs[1](A.t(), B.t())

            n = out_bnb.numel()
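            # tolerance check: only a small fraction of elements may miss the tight
            # threshold, and almost none may miss the loose one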
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_close(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
                torch.testing.assert_close(
                    gradB1, gradB2, atol=0.18, rtol=0.3
                )

        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            B = torch.randn(
                size=(dim1, dim3, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            torch.testing.assert_close(
                out_bnb, out_torch, atol=0.027, rtol=0.2
            )

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_close(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02

        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            else:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_close(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02


n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)
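# include dim2 == 0 so the empty-input code path is exercised as well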

decomp = [0.0, 6.0]
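# thresholds used for MatmulLtState.threshold: 0.0 disables the mixed-precision (outlier) decomposition path, 6.0 enables it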
funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
str_funcs = ["matmullt", 'switchback_bnb']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if v else "F" for v in c) for c in req_grad]

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.bfloat16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
    values,
    ids=names,
)
def test_matmullt(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    decomp,
    has_fp16_weights,
    has_bias
):
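    # Exercise the 8-bit matmul implementations (bnb.matmul / switchback_bnb) against a
    # torch.matmul reference, with and without pre-quantized weights, outlier
    # decomposition, and an optional bias.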
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
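    # columns of A that are filled with a large constant so the decomposition path triggers when decomp == 6.0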
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
    if not has_bias:
        req_grad = list(req_grad)
        req_grad[2] = False

    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()

            state = bnb.MatmulLtState()
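            # threshold selects the outlier decomposition cutoff; has_fp16_weights=False means
            # B is handed over already quantized (CB/SCB are filled in below)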
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
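                # quantize B to int8; state.CB holds the int8 weights and state.SCB the quantization statistics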
                (
                    state.CB,
                    CBt,
                    state.SCB,
                    SCBt,
                    coo_tensorB,
                ) = bnb.functional.double_quant(B2.to(torch.float16))
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2, state=state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)

            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            # print(f'abs error {err:.4f}')

            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() <= n * 0.001

            if has_fp16_weights:
                if any(req_grad):
                    out_bnb.data.copy_(out_torch)
                    torch.cuda.synchronize()
                    loss_bnb = torch.nn.functional.mse_loss(
                        out_bnb, target
                    ).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None
                    if has_bias:
                        gradBias1 = bias.grad
                        bias.grad = None

                    loss_torch = torch.nn.functional.mse_loss(
                        out_torch, target
                    ).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
                    A.grad = None
                    B.grad = None
                    if has_bias:
                        gradBias2 = bias.grad
                        bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_close(
                        gradA1, gradA2, atol=0.015, rtol=0.1
                    )
                if req_grad[1]:
                    n = gradB1.numel()
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)

                    assert (idx == 0).sum().item() <= n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.02
                    torch.testing.assert_close(
                        gradB1, gradB2, atol=0.18, rtol=0.3
                    )

                if req_grad[2]:
                    torch.testing.assert_close(gradBias1, gradBias2)


n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

funcs = [(torch.matmul, bnb.matmul_4bit)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if v else "F" for v in c) for c in req_grad]

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
compress_statistics = [False, True]
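# compress_statistics=True additionally quantizes the absmax statistics ("double quantization")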
has_fp16_weights = [True, False]
has_bias = [True, False]
quant_type = ['fp4', 'nf4']
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
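    # Check 4-bit (FP4/NF4) matmul against a torch reference: the forward pass is compared via
    # its mean absolute error, and gradients are checked for A and the optional bias.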
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    if not has_bias:
        req_grad = list(req_grad)
        req_grad[2] = False

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)

            B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
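            # quant_state carries the metadata (absmax, blocksize, dtype, shape) needed to dequantize B2 inside matmul_4bit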

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2, quant_state, bias=bias2)

            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmul_4bit received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).float().mean().item()
            if n > 0:
                assert err < 0.115

                #assert err < 0.20
            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias1 = bias.grad
                    bias.grad = None

                loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias2 = bias.grad
                    bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)

                if req_grad[2]:
                    torch.testing.assert_close(gradBias1, gradBias2)


funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if v else "F" for v in c) for c in req_grad]

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
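    # Simulated FP8 matmul from bnb.research: inputs are quantized with one FP8 code in the
    # forward pass and another in the backward pass, then compared against torch.matmul.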
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    req_grad = list(req_grad)
    req_grad[2] = False

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)

            torch.nn.init.xavier_uniform_(B)

            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
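            # FP8 quantization maps: roughly e4m3 (4 exponent / 3 mantissa bits) for the forward pass and e5m2 for the backward pass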

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B, fw_code, bw_code)

            assert out_bnb.dtype == A.dtype, f"bnb matmul_fp8 received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).float().mean().item()
            if n > 0:
                assert err < 0.115
                #assert err < 0.20
            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)

                if req_grad[1]:
                    n = gradB1.numel()
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)

                    assert (idx == 0).sum().item() <= n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.02
                    grad_err = (gradB1-gradB2).abs().mean()
                    assert grad_err.item() < 0.003
                    torch.testing.assert_close(
                        gradB1, gradB2, atol=0.18, rtol=0.3
                    )