from itertools import product

import pytest
import torch

import bitsandbytes as bnb

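# Sweep configuration: n random sizes per dimension, k repetitions of each
# parametrized case.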
n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
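# Human-readable pytest ids assembled from the stringified parameters.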
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
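    # Truncate dims to multiples of 16 (presumably to match the alignment
    # the underlying CUDA kernels expect).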
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)
            elif not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            elif transpose[0] and not transpose[1]:
                out_torch = funcs[0](A.t(), B)
                out_bnb = funcs[1](A.t(), B)
            elif transpose[0] and transpose[1]:
                out_torch = funcs[0](A.t(), B.t())
                out_bnb = funcs[1](A.t(), B.t())

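            # Fuzzy equality: fewer than 1.75% of elements may miss the tight
            # tolerance, and fewer than 0.1% the looser one.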
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

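            # Copy the torch output into the bnb output so both losses (and
            # hence the upstream gradients) match exactly; any remaining
            # gradient difference comes from the backward kernels alone.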
            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
                torch.testing.assert_allclose(
                    gradB1, gradB2, atol=0.18, rtol=0.3
                )

        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            B = torch.randn(
                size=(dim1, dim3, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            torch.testing.assert_allclose(
                out_bnb, out_torch, atol=0.027, rtol=0.2
            )

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02

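        # matmul of a batched 3D input against a 2D weight matrix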
        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            else:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

            if req_grad[0]:
                torch.testing.assert_allclose(
                    gradA1, gradA2, atol=0.015, rtol=0.1
                )
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02


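# Configuration for the 8-bit bnb.matmul test: `decomp` is the outlier
# decomposition threshold, and dim2 == 0 adds an empty-input edge case.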
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
    values,
    ids=names,
)
def test_matmullt(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    decomp,
    has_fp16_weights,
):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

    for i in range(k):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
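            # Plant outlier values so the threshold == 6.0 cases exercise the
            # mixed-precision decomposition path.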
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()

            state = bnb.MatmulLtState()
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
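            # Without fp16 weights, quantize B up front: state.CB holds the
            # int8 weights and state.SCB their quantization statistics.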
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                (
                    state.CB,
                    CBt,
                    state.SCB,
                    SCBt,
                    coo_tensorB,
                ) = bnb.functional.double_quant(B2)
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2, state=state)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2.t(), state=state)

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            # print(f'abs error {err:.4f}')
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() <= n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() <= n * 0.001

            if has_fp16_weights:
                if any(req_grad):
                    out_bnb.data.copy_(out_torch)
                    torch.cuda.synchronize()
                    loss_bnb = torch.nn.functional.mse_loss(
                        out_bnb, target
                    ).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None

                    loss_torch = torch.nn.functional.mse_loss(
                        out_torch, target
                    ).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
                    A.grad = None
                    B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(
                        gradA1, gradA2, atol=0.015, rtol=0.1
                    )
                if req_grad[1]:
                    n = gradB1.numel()
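                    # With an empty input (dim2 == 0) the weight gradient must
                    # be exactly zero; otherwise it should be non-trivial.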
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.02
                    torch.testing.assert_allclose(
                        gradB1, gradB2, atol=0.18, rtol=0.3
                    )