from itertools import product

import pytest
import torch

import bitsandbytes as bnb

n = 1
k = 25
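# k: inner repetitions per parametrized case; the dims below are sampled
# randomly and rounded down to multiples of 16 inside the tests.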
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
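# Each tuple pairs a torch reference implementation with its bnb counterpart.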
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
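    # Round dims down to multiples of 16, the alignment the CUBLAS-backed
    # kernels are assumed to require.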
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)
            elif not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            elif transpose[0] and not transpose[1]:
                out_torch = funcs[0](A.t(), B)
                out_bnb = funcs[1](A.t(), B)
            elif transpose[0] and transpose[1]:
                out_torch = funcs[0](A.t(), B.t())
                out_bnb = funcs[1](A.t(), B.t())

            n = out_bnb.numel()
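            # Two-tier check: at most 1.75% of entries may miss the tight
            # tolerance and at most 0.1% the looser one.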
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
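                # Copy the torch output into the bnb tensor so both losses start
                # from identical forward values; only the backward paths differ.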
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
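                # B gradients are checked with staged error-rate bounds that are
                # looser than the tolerances used for A.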
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
                torch.testing.assert_allclose(
                    gradB1, gradB2, atol=0.18, rtol=0.3
                )

        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            B = torch.randn(
                size=(dim1, dim3, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            torch.testing.assert_allclose(
                out_bnb, out_torch, atol=0.027, rtol=0.2
            )

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02

        if funcs[0] in [torch.matmul]:
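            # 3D @ 2D case: torch.matmul broadcasts the 2D B across the batch of A.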
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            else:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02


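# Second test block: the int8 matmul path (bnb.matmul), fp16 only,
# with fewer inner repetitions.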
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# dim1 = (17,)
# dim2 = (7,)
# dim3 = (37,)
# dim4 = (23,)

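# decomp: outlier threshold for the int8 mixed-precision decomposition
# (0.0 disables the fp16 outlier path).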
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
    values,
    ids=names,
)
def test_matmullt(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    decomp,
    has_fp16_weights,
):
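    # Compare the int8 bnb.matmul (with optional outlier decomposition and
    # optional pre-quantized weights) against a torch.matmul reference.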
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
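    # Columns that will be overwritten with large values so the decomposition
    # path is exercised when decomp == 6.0.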
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()

            state = bnb.MatmulLtState()
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
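                # Pre-quantize B to int8 (CB plus quantization stats SCB) to
                # emulate inference against stored 8-bit weights.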
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                (
                    state.CB,
                    CBt,
                    state.SCB,
                    SCBt,
                    coo_tensorB,
                ) = bnb.functional.double_quant(B2)
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2, state=state)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2.t(), state=state)

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            # print(f'abs error {err:.4f}')
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if has_fp16_weights:
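                # Gradients are only compared when B keeps fp16 weights; the
                # pre-quantized int8 path has no matching reference backward.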
                if any(req_grad):
                    out_bnb.data.copy_(out_torch)
                    torch.cuda.synchronize()
                    loss_bnb = torch.nn.functional.mse_loss(
                        out_bnb, target
                    ).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None

                    loss_torch = torch.nn.functional.mse_loss(
                        out_torch, target
                    ).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
                    A.grad = None
                    B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(
                        gradA1, gradA2, atol=0.015, rtol=0.1
                    )
                if req_grad[1]:
                    n = gradB1.numel()
                    assert torch.abs(gradB1).sum() > 0.0
                    assert torch.abs(gradB2).sum() > 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.02
                    torch.testing.assert_allclose(
                        gradB1, gradB2, atol=0.18, rtol=0.3
                    )