"vscode:/vscode.git/clone" did not exist on "d24d3693329b274e99d541b86dc07662391ed38e"
test_autograd.py 18.5 KB
Newer Older
Tom Aarsen's avatar
Tom Aarsen committed
1
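"""Tests for the bitsandbytes autograd matmul paths (cuBLAS, LLM.int8 and FP4)."""
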
from itertools import permutations, product

import pytest
import torch

import bitsandbytes as bnb

n = 1   # number of random sizes drawn for each dimension below
k = 25  # repetitions of each parametrized case inside the test body
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
# Each entry pairs a torch reference op with its bitsandbytes counterpart.
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
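    """Compare the bnb cuBLAS wrappers against their torch counterparts.

    Runs 2D, batched (3D) and broadcast (3D @ 2D) multiplications with random
    shapes, checks the forward outputs within loose tolerances, and compares
    gradients whenever requires_grad is set on an input.
    """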
    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)
            elif not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            elif transpose[0] and not transpose[1]:
                out_torch = funcs[0](A.t(), B)
                out_bnb = funcs[1](A.t(), B)
            elif transpose[0] and transpose[1]:
                out_torch = funcs[0](A.t(), B.t())
                out_bnb = funcs[1](A.t(), B.t())

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
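                # Make both outputs numerically identical before building the
                # losses, so only the backward implementations are compared.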
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
                torch.testing.assert_allclose(
                    gradB1, gradB2, atol=0.18, rtol=0.3
                )

        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            B = torch.randn(
                size=(dim1, dim3, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            torch.testing.assert_allclose(
                out_bnb, out_torch, atol=0.027, rtol=0.2
            )

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02

        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            else:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02


n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)  # also exercise empty inputs (zero-sized second dimension)

decomp = [0.0, 6.0]  # LLM.int8 decomposition thresholds (0.0 disables the outlier path)
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if v else "F" for v in c) for c in req_grad]

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.bfloat16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
    values,
    ids=names,
)
def test_matmullt(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    decomp,
    has_fp16_weights,
    has_bias
):
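    """Check bnb.matmul (LLM.int8) against torch.matmul.

    Covers the mixed-precision decomposition threshold, pre-quantized int8
    weights (has_fp16_weights=False), optional bias, and gradient agreement
    for A, B and the bias.
    """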
    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
    if not has_bias:
        req_grad = list(req_grad)
        req_grad[2] = False

    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
            # With a 6.0 threshold, plant outlier values in A so the
            # mixed-precision decomposition path is actually exercised.
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()

            state = bnb.MatmulLtState()
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
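            # Without fp16 weights, int8-quantize the weight copy with
            # double_quant and pass it to matmul through state.CB.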
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                (
                    state.CB,
                    CBt,
                    state.SCB,
                    SCBt,
                    coo_tensorB,
                ) = bnb.functional.double_quant(B2.to(torch.float16))
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2, state=state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)

            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            # print(f'abs error {err:.4f}')

            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() <= n * 0.001

            if has_fp16_weights:
                if any(req_grad):
                    out_bnb.data.copy_(out_torch)
                    torch.cuda.synchronize()
                    loss_bnb = torch.nn.functional.mse_loss(
                        out_bnb, target
                    ).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None
                    if has_bias:
                        gradBias1 = bias.grad
                        bias.grad = None

                    loss_torch = torch.nn.functional.mse_loss(
                        out_torch, target
                    ).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
                    A.grad = None
                    B.grad = None
                    if has_bias:
                        gradBias2 = bias.grad
                        bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(
                        gradA1, gradA2, atol=0.015, rtol=0.1
                    )
                if req_grad[1]:
                    n = gradB1.numel()
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)

                    assert (idx == 0).sum().item() <= n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.02
                    torch.testing.assert_allclose(
                        gradB1, gradB2, atol=0.18, rtol=0.3
                    )

                if req_grad[2]:
                    torch.testing.assert_allclose(gradBias1, gradBias2)


n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)  # also exercise empty inputs (zero-sized second dimension)

funcs = [(torch.matmul, bnb.matmul_fp4)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if v else "F" for v in c) for c in req_grad]

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
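    """Check bnb.matmul_fp4 against torch.matmul using 4-bit quantized weights."""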
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    if not has_bias:
        req_grad = list(req_grad)
        req_grad[2] = False

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)

            # Quantize the weights to 4-bit FP4; matmul_fp4 consumes the packed
            # tensor together with its quantization state.
            B2, quant_state = bnb.functional.quantize_fp4(B)

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2, quant_state, bias=bias2)

            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).float().mean().item()
            if n > 0:
                assert err < 0.115

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias1 = bias.grad
                    bias.grad = None

                loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias2 = bias.grad
                    bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1)

                if req_grad[2]:
                    torch.testing.assert_allclose(gradBias1, gradBias2)