"""GPU tests comparing bitsandbytes matmul wrappers with torch reference ops."""
from itertools import product, permutations

import pytest
import torch

import bitsandbytes as bnb

# ---------------------------------------------------------------------------
# Parameter grid for test_matmul.
# ---------------------------------------------------------------------------
n = 1  # number of random sizes drawn for each dimension
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# Pairs of (reference torch op, bitsandbytes op under test).
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
# Pairs of (A.requires_grad, B.requires_grad).
req_grad = [(False, False), (True, False), (True, True), (False, True)]
# Pairs of (transpose A, transpose B).
transpose = [(False, False), (False, True), (True, True), (True, False)]
# The id labels are derived from the boolean pairs ("TF" means True, False),
# so a label can never disagree with the value it names.
req_grad_str = ["".join("T" if flag else "F" for flag in pair) for pair in req_grad]
str_transpose = [
    "".join("T" if flag else "F" for flag in pair) for pair in transpose
]
dtype = [torch.float32, torch.float16]

values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
        *vals
    )
    for vals in str_values
]


# Iterations per test_matmul case. This is a dedicated constant rather than
# the module-level ``k``, because ``k`` is rebound further down this module
# before pytest runs any test. Reading the global at call time therefore did
# not give the 25 iterations configured for this test.
_TEST_MATMUL_ITERS = 25


def _assert_matmul_outputs_close(out_bnb, out_torch):
    """Two-tier elementwise comparison of a bnb output against torch.

    Fewer than 1.75% of elements may fall outside (atol=0.01, rtol=0.1), and
    fewer than 0.1% may fall outside the looser (atol=0.035, rtol=0.2).
    """
    n = out_bnb.numel()
    idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
    assert (idx == 0).sum().item() < n * 0.0175
    idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
    assert (idx == 0).sum().item() < n * 0.001


def _matmul_backward_grads(out_bnb, out_torch, target, A, B):
    """Backprop an MSE loss through both outputs and return the gradients.

    ``out_bnb`` is first overwritten with ``out_torch``. The two losses are
    then identical, so any gradient mismatch comes from the backward passes
    alone. ``A.grad`` and ``B.grad`` are reset to ``None`` after each pass.

    Returns:
        ``(gradA_bnb, gradB_bnb, gradA_torch, gradB_torch)``. An entry is
        ``None`` when the corresponding input does not require grad.
    """
    out_bnb.data.copy_(out_torch)
    torch.cuda.synchronize()
    grads = []
    for out in (out_bnb, out_torch):
        loss = torch.nn.functional.mse_loss(out, target).mean()
        loss.backward()
        grads.append((A.grad, B.grad))
        A.grad = None
        B.grad = None
    (gradA1, gradB1), (gradA2, gradB2) = grads
    return gradA1, gradB1, gradA2, gradB2


def _assert_matmul_grads_close(req_grad, grads, check_b_allclose):
    """Compare bnb gradients with torch gradients for whichever inputs need them.

    Args:
        req_grad: ``(A.requires_grad, B.requires_grad)``.
        grads: Tuple returned by ``_matmul_backward_grads``.
        check_b_allclose: Also require every element of grad(B) to lie within
            (atol=0.18, rtol=0.3), not only the fraction-based checks.
    """
    gradA1, gradB1, gradA2, gradB2 = grads
    if req_grad[0]:
        torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
    if req_grad[1]:
        n = gradB1.numel()
        idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
        assert (idx == 0).sum().item() < n * 0.1
        idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
        assert (idx == 0).sum().item() < n * 0.02
        if check_b_allclose:
            torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    """Compare ``funcs[1]`` (bnb op) with ``funcs[0]`` (torch op), forward and backward.

    Three cases are exercised per iteration, depending on the torch op:
    2-D x 2-D with optional transposes (mm/matmul), batched 3-D x 3-D
    (bmm/matmul), and 3-D x 2-D with optional transposed B (matmul).
    """
    # NOTE(review): ``dtype`` is never used below. Every tensor is created
    # with torch's default dtype, so the float16 ids repeat the float32 runs.
    # TODO: confirm whether this is intended.
    if not torch.cuda.is_available():
        pytest.skip("No GPU found.")

    # Round the inner and output dims down to multiples of 16.
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    # The 3-D x 2-D case also uses a 16-aligned batch dim. Computing it once,
    # into a separate name, keeps the batched case on the same dim1 in every
    # iteration. Previously dim1 was rebound inside the loop, so only the
    # first iteration saw the unaligned value.
    dim1_aligned = dim1 - (dim1 % 16)

    for _ in range(_TEST_MATMUL_ITERS):
        # 2-D x 2-D multiply, optionally transposing either operand.
        if funcs[0] in (torch.mm, torch.matmul):
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            lhs = A.t() if transpose[0] else A
            rhs = B.t() if transpose[1] else B
            out_torch = funcs[0](lhs, rhs)
            out_bnb = funcs[1](lhs, rhs)

            _assert_matmul_outputs_close(out_bnb, out_torch)
            if any(req_grad):
                grads = _matmul_backward_grads(out_bnb, out_torch, target, A, B)
                _assert_matmul_grads_close(req_grad, grads, check_b_allclose=True)

        # Batched 3-D x 3-D multiply.
        if funcs[0] in (torch.bmm, torch.matmul):
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            B = torch.randn(
                size=(dim1, dim3, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            # check_dtype=False mirrors the dtype-agnostic assert_allclose
            # that this replaces.
            torch.testing.assert_close(
                out_bnb, out_torch, atol=0.027, rtol=0.2, check_dtype=False
            )

            if any(req_grad):
                grads = _matmul_backward_grads(out_bnb, out_torch, target, A, B)
                _assert_matmul_grads_close(req_grad, grads, check_b_allclose=False)

        # 3-D x 2-D matmul (B broadcast over the batch), optionally with B transposed.
        if funcs[0] is torch.matmul:
            A = torch.randn(
                size=(dim1_aligned, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1_aligned, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            rhs = B.t() if transpose[1] else B
            out_torch = funcs[0](A, rhs)
            out_bnb = funcs[1](A, rhs)

            _assert_matmul_outputs_close(out_bnb, out_torch)
            if any(req_grad):
                grads = _matmul_backward_grads(out_bnb, out_torch, target, A, B)
                _assert_matmul_grads_close(req_grad, grads, check_b_allclose=False)


# ---------------------------------------------------------------------------
# Parameter grid for test_matmullt.
# ---------------------------------------------------------------------------
n = 1  # number of random sizes drawn for each dimension
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# Also run with an empty A (zero rows). The gradient checks in test_matmullt
# expect grad(B) to be all zeros in that case.
dim2.append(0)

# Values assigned to MatmulLtState.threshold. With 6.0, the test also writes
# 6.0 into some columns of A.
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
# Triples of (A.requires_grad, B.requires_grad, bias.requires_grad).
req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if flag else "F" for flag in triple) for triple in req_grad]

# Pairs of (transpose A, transpose B). Labels: N = as-is, T = transposed.
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.bfloat16]
has_fp16_weights = [True, False]
has_bias = [True, False]

values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
        has_bias,
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
        has_bias,
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(
        *vals
    )
    for vals in str_values
]


# Iterations per test_matmullt case. This is a dedicated constant so the test
# does not depend on whatever value the module-level ``k`` last held.
_TEST_MATMULLT_ITERS = 3


def _lt_backward_grads(out_bnb, out_torch, target, A, B, bias):
    """Backprop an MSE loss through both outputs and return their gradients.

    ``out_bnb`` is first overwritten with ``out_torch``. The two losses are
    then identical, so any gradient mismatch comes from the backward passes
    alone. ``A.grad``, ``B.grad`` and (when present) ``bias.grad`` are reset
    to ``None`` after each pass.

    Returns:
        ``[(gradA, gradB, gradBias) for bnb, (gradA, gradB, gradBias) for torch]``.
        ``gradBias`` is ``None`` when ``bias`` is ``None``.
    """
    out_bnb.data.copy_(out_torch)
    torch.cuda.synchronize()
    grads = []
    for out in (out_bnb, out_torch):
        loss = torch.nn.functional.mse_loss(out, target).mean()
        loss.backward()
        grads.append((A.grad, B.grad, None if bias is None else bias.grad))
        A.grad = None
        B.grad = None
        if bias is not None:
            bias.grad = None
    return grads


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
    values,
    ids=names,
)
def test_matmullt(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    decomp,
    has_fp16_weights,
    has_bias,
):
    """Compare ``funcs[1]`` (bnb op driven by a MatmulLtState) against ``funcs[0]``.

    ``decomp`` becomes ``state.threshold``. When it is 6.0, some columns of A
    are also filled with 6.0. When ``has_fp16_weights`` is False, B is
    pre-processed with ``bnb.functional.double_quant`` and ``state.CB`` is
    passed in place of B. In that mode only the forward output is checked.
    An optional bias is added to the torch reference and passed to bnb.
    """
    if not torch.cuda.is_available():
        pytest.skip("No GPU found.")

    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    # Random columns of A that receive the value 6.0 when decomp == 6.0.
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

    if not has_bias:
        # With no bias tensor there is no bias gradient to compare.
        req_grad = list(req_grad)
        req_grad[2] = False

    for _ in range(_TEST_MATMULLT_ITERS):
        if funcs[0] not in (torch.mm, torch.matmul):
            continue

        A = torch.randn(
            size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
        )
        if decomp == 6.0:
            with torch.no_grad():
                A[:, outlier_dim] = 6.0
        B = torch.randn(
            size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
        )
        target = torch.randn(
            size=(dim2, dim4),
            device="cuda",
            requires_grad=req_grad[1],
            dtype=dtype,
        )
        bias = None
        bias2 = None
        if has_bias:
            bias = torch.randn(
                dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]
            )
            # bnb receives a clone. Gradients still flow back to ``bias``
            # because clone() is differentiable.
            bias2 = bias.clone()
        torch.nn.init.xavier_uniform_(B)
        B2 = B.clone()

        state = bnb.MatmulLtState()
        state.threshold = decomp
        state.has_fp16_weights = has_fp16_weights
        if not has_fp16_weights:
            if not transpose[0] and not transpose[1]:
                B2 = B2.t().contiguous()
            (
                state.CB,
                _CBt,
                state.SCB,
                _SCBt,
                _coo_tensorB,
            ) = bnb.functional.double_quant(B2.to(torch.float16))
            B2 = state.CB

        # Only the transpose combinations present in the grid (A never
        # transposed) are handled here.
        if not transpose[0] and transpose[1]:
            out_torch = funcs[0](A, B.t())
            out_bnb = funcs[1](A, B2, state=state, bias=bias2)
        elif not transpose[0] and not transpose[1]:
            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)

        if has_bias:
            out_torch += bias

        assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

        n = out_bnb.numel()
        idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
        assert (idx == 0).sum().item() <= n * (
            0.0175 if dtype == torch.float16 else 0.021
        )
        idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
        assert (idx == 0).sum().item() <= n * 0.001

        # Gradients are only compared when B is kept in its original dtype.
        if not (has_fp16_weights and any(req_grad)):
            continue

        (gradA1, gradB1, gradBias1), (gradA2, gradB2, gradBias2) = (
            _lt_backward_grads(out_bnb, out_torch, target, A, B, bias)
        )

        if req_grad[0]:
            torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
        if req_grad[1]:
            n = gradB1.numel()
            if dim2 > 0:
                assert torch.abs(gradB1).sum() > 0.0
                assert torch.abs(gradB2).sum() > 0.0
            else:
                # An empty A contributes nothing to grad(B).
                assert torch.abs(gradB1).sum() == 0.0
                assert torch.abs(gradB2).sum() == 0.0
            idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
            assert (idx == 0).sum().item() <= n * 0.1
            idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
            assert (idx == 0).sum().item() <= n * 0.02
            torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
        if req_grad[2]:
            torch.testing.assert_close(gradBias1, gradBias2)