"docs_zh_CN/conf.py" did not exist on "da39212f94a3d433cf99e55cfecd77ebf359e0ca"
test_autograd.py 20.9 KB
Newer Older
Aarni Koskela's avatar
Aarni Koskela committed
1
from typing import Tuple
Tim Dettmers's avatar
Tim Dettmers committed
2

3
import pytest
Tim Dettmers's avatar
Tim Dettmers committed
4
5
import torch

6
import bitsandbytes as bnb
Aarni Koskela's avatar
Aarni Koskela committed
7
8
9
10
11
12
13
from tests.helpers import (
    BOOLEAN_TRIPLES,
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
14
)
Aarni Koskela's avatar
Aarni Koskela committed
15
16
17
18
19
20
21
22

# (transpose_A, transpose_B) flag pairs used to parametrize ``transpose`` in
# test_matmullt, test_matmul_4bit and test_matmul_fp8. The first flag is False
# in every pair, i.e. only the right-hand operand is ever transposed.
TRANSPOSE_VALS = [(False, True), (False, False)]


def _matmul_mismatches(actual, expected, *, atol, rtol):
    """Count the elements of ``actual`` that are not close to ``expected``."""
    return (~torch.isclose(actual, expected, atol=atol, rtol=rtol)).sum().item()


def _matmul_backward(out, target, A, B):
    """Backpropagate an MSE loss on ``out`` and return ``(A.grad, B.grad)``.

    Both ``.grad`` fields are reset to ``None`` afterwards so that a second
    backward pass through the same leaves does not accumulate into them.
    """
    torch.nn.functional.mse_loss(out, target).mean().backward()
    grads = (A.grad, B.grad)
    A.grad = None
    B.grad = None
    return grads


def _matmul_check_grads(A, B, out_bnb, out_torch, target, req_grad, *, strict_grad_b=False):
    """Compare the gradients flowing through the bnb and torch outputs.

    Args:
        A, B: leaf operands that produced both outputs.
        out_bnb: result of the bitsandbytes function.
        out_torch: result of the reference torch function.
        target: regression target for the MSE loss.
        req_grad: pair of flags saying whether ``A`` / ``B`` require grad.
        strict_grad_b: additionally run an element-wise ``assert_close`` on
            the gradients of ``B``.
    """
    if not any(req_grad):
        return

    # Overwrite the bnb output values with the reference values so that both
    # losses are evaluated on identical numbers; only the backward differs.
    out_bnb.data.copy_(out_torch)
    torch.cuda.synchronize()
    grad_a_bnb, grad_b_bnb = _matmul_backward(out_bnb, target, A, B)
    grad_a_ref, grad_b_ref = _matmul_backward(out_torch, target, A, B)

    if req_grad[0]:
        torch.testing.assert_close(grad_a_bnb, grad_a_ref, atol=0.015, rtol=0.1)

    if req_grad[1]:
        n = grad_b_bnb.numel()
        assert _matmul_mismatches(grad_b_bnb, grad_b_ref, atol=0.06, rtol=0.3) < n * 0.1
        assert _matmul_mismatches(grad_b_bnb, grad_b_ref, atol=0.10, rtol=0.3) < n * 0.02
        if strict_grad_b:
            torch.testing.assert_close(grad_b_bnb, grad_b_ref, atol=0.18, rtol=0.3)


@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize(
    "funcs",
    [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
    ids=["func=bmm", "func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
    """Check ``bnb.bmm_cublas`` / ``bnb.matmul_cublas`` against their torch counterparts.

    ``funcs`` is a ``(torch_fn, bnb_fn)`` pair. Depending on ``torch_fn`` up to
    three operand layouts are exercised per iteration: 2D x 2D (honouring both
    ``transpose`` flags), 3D x 3D, and 3D x 2D (honouring ``transpose[1]``).
    Forward outputs and, when requested via ``req_grad``, gradients of an MSE
    loss are compared within loose tolerances.

    NOTE(review): ``dtype`` is parametrized but never passed to the tensors
    created below, so every case runs with torch's default dtype -- confirm
    whether that is intended before wiring it in (tolerances may need tuning).
    """
    torch_fn, bnb_fn = funcs

    # Round the matrix dimensions down to multiples of 16.
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    # Rounded batch size for the 3D x 2D case only. Kept in its own variable:
    # reassigning ``dim1`` inside the loop would silently change the batch size
    # of the 3D x 3D case from the second iteration on.
    batch16 = dim1 - (dim1 % 16)

    for _ in range(25):
        # normal multiply
        if torch_fn in [torch.mm, torch.matmul]:
            dimA = (dim3, dim2) if transpose[0] else (dim2, dim3)
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            # Operands are allocated in transposed layout when the flag is set
            # and transposed back here, so the product is always dim2 x dim4.
            lhs = A.t() if transpose[0] else A
            rhs = B.t() if transpose[1] else B
            out_torch = torch_fn(lhs, rhs)
            out_bnb = bnb_fn(lhs, rhs)

            n = out_bnb.numel()
            assert _matmul_mismatches(out_bnb, out_torch, atol=0.01, rtol=0.1) < n * 0.0175
            assert _matmul_mismatches(out_bnb, out_torch, atol=0.035, rtol=0.2) < n * 0.001

            _matmul_check_grads(A, B, out_bnb, out_torch, target, req_grad, strict_grad_b=True)

        # batched matrix multiply
        if torch_fn in [torch.bmm, torch.matmul]:
            A = torch.randn(size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1])
            target = torch.randn(size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            out_torch = torch_fn(A, B)
            out_bnb = bnb_fn(A, B)

            n = out_bnb.numel()
            assert _matmul_mismatches(out_bnb, out_torch, atol=0.01, rtol=0.1) < n * 0.01
            torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)

            _matmul_check_grads(A, B, out_bnb, out_torch, target, req_grad)

        # batched left operand times a single (optionally transposed) matrix
        if torch_fn in [torch.matmul]:
            A = torch.randn(size=(batch16, dim2, dim3), device="cuda", requires_grad=req_grad[0])
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(size=(batch16, dim2, dim4), device="cuda", requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            rhs = B.t() if transpose[1] else B
            out_torch = torch_fn(A, rhs)
            out_bnb = bnb_fn(A, rhs)

            n = out_bnb.numel()
            assert _matmul_mismatches(out_bnb, out_torch, atol=0.01, rtol=0.1) < n * 0.0175
            assert _matmul_mismatches(out_bnb, out_torch, atol=0.035, rtol=0.2) < n * 0.001

            _matmul_check_grads(A, B, out_bnb, out_torch, target, req_grad)
Tim Dettmers's avatar
Tim Dettmers committed
199
200


Aarni Koskela's avatar
Aarni Koskela committed
201
202
203
204
205
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize(
    "funcs",
    [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)],
    ids=["func=matmul", "func=switchback_bnb"],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
    """Check ``bnb.matmul`` / ``bnb.research.switchback_bnb`` against ``torch.matmul``.

    ``funcs`` is a ``(torch_fn, bnb_fn)`` pair. The bnb function is driven
    through a ``bnb.MatmulLtState`` whose ``threshold`` is ``decomp``; when
    ``has_fp16_weights`` is false the weight is quantized up front with
    ``bnb.functional.double_quant``. The forward result must keep the input
    dtype and stay close to the reference. Gradients of an MSE loss are only
    compared when ``has_fp16_weights`` is true.

    ``req_grad`` holds the requires-grad flags for ``(A, B, bias)``; ``dim2``
    may be 0, which yields empty outputs. ``dim1`` is not used by the body.
    """
    torch_fn, bnb_fn = funcs
    dimA = (dim3, dim2) if transpose[0] else (dim2, dim3)
    dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
    # Random column indices of A that are overwritten with 6.0 when decomp == 6.0.
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
    if not has_bias:
        # Without a bias tensor there is nothing whose gradient could be checked.
        req_grad = list(req_grad)
        req_grad[2] = False

    for _ in range(3):
        # normal multiply
        if torch_fn in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()

            state = bnb.MatmulLtState()
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                # Only the first and third results of double_quant are kept.
                state.CB, _, state.SCB, _, _ = bnb.functional.double_quant(B2.to(torch.float16))
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = torch_fn(A, B.t())
                out_bnb = bnb_fn(A, B2, state=state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = torch_fn(A, B)
                out_bnb = bnb_fn(A, B2.t(), state=state, bias=bias2)

            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            close = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (~close).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
            close = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (~close).sum().item() <= n * 0.001

            if has_fp16_weights:
                if any(req_grad):
                    # Overwrite the bnb output values with the reference values so
                    # both losses are evaluated on identical numbers.
                    out_bnb.data.copy_(out_torch)
                    torch.cuda.synchronize()
                    loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None
                    if has_bias:
                        gradBias1 = bias.grad
                        bias.grad = None

                    loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
                    A.grad = None
                    B.grad = None
                    if has_bias:
                        gradBias2 = bias.grad
                        bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)

                if req_grad[1]:
                    n = gradB1.numel()
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        # Empty A: the gradients w.r.t. B must be exactly zero.
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0

                    close = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (~close).sum().item() <= n * 0.1
                    close = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (~close).sum().item() <= n * 0.02
                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

                if req_grad[2]:
                    torch.testing.assert_close(gradBias1, gradBias2)
Tim Dettmers's avatar
Tim Dettmers committed
325
326


Aarni Koskela's avatar
Aarni Koskela committed
327
328
329
330
331
332
333
334
335
336
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"])
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
def test_matmul_4bit(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    has_bias,
    compress_statistics,
    quant_type,
):
    """Check ``bnb.matmul_4bit`` against ``torch.matmul``.

    The weight ``B`` is quantized with ``bnb.functional.quantize_4bit`` using
    the given ``quant_type`` and ``compress_statistics``. The forward result
    must keep the input dtype and have a mean absolute error below 0.115
    (skipped for empty outputs, i.e. ``dim2 == 0``). When any entry of
    ``req_grad`` -- flags for ``(A, B, bias)`` -- is set, the gradients of an
    MSE loss w.r.t. ``A`` and the bias are compared with the reference;
    gradients w.r.t. ``B`` are not compared. ``dim1`` is not used by the body.
    """
    torch_fn, bnb_fn = funcs
    dimA = (dim3, dim2) if transpose[0] else (dim2, dim3)
    dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
    if not has_bias:
        # Without a bias tensor there is nothing whose gradient could be checked.
        req_grad = list(req_grad)
        req_grad[2] = False

    for _ in range(3):
        # normal multiply
        if torch_fn in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)

            B2, quant_state = bnb.functional.quantize_4bit(
                B,
                compress_statistics=compress_statistics,
                quant_type=quant_type,
            )

            if not transpose[0] and transpose[1]:
                out_torch = torch_fn(A, B.t())
                out_bnb = bnb_fn(A, B2.t(), quant_state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = torch_fn(A, B)
                out_bnb = bnb_fn(A, B2, quant_state, bias=bias2)

            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmul_4bit received {A.dtype} but returned {out_bnb.dtype}"

            # The mean of an empty tensor is NaN, so only check non-empty outputs.
            if out_bnb.numel() > 0:
                err = torch.abs(out_bnb - out_torch).float().mean().item()
                assert err < 0.115

            if any(req_grad):
                # Overwrite the bnb output values with the reference values so
                # both losses are evaluated on identical numbers.
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias1 = bias.grad
                    bias.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias2 = bias.grad
                    bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)

                if req_grad[2]:
                    torch.testing.assert_close(gradBias1, gradBias2)


Aarni Koskela's avatar
Aarni Koskela committed
424
425
426
427
428
429
430
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
    "funcs",
    [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
    ids=["matmul_fp8_mixed", "matmul_fp8_global"],
)
def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    """Check ``bnb.research.matmul_fp8_mixed`` / ``matmul_fp8_global`` against ``torch.matmul``.

    The bnb function receives two code maps built with
    ``bnb.functional.create_fp8_map`` (named ``fw_code`` and ``bw_code``
    below). The forward result must keep the input dtype and have a mean
    absolute error below 0.115 (skipped for empty outputs, i.e.
    ``dim2 == 0``). When ``A`` or ``B`` requires grad, gradients of an MSE
    loss are compared with the reference. The third entry of ``req_grad`` is
    forced to False because no bias is involved. ``dim1`` is not used by the
    body.
    """
    torch_fn, bnb_fn = funcs
    dimA = (dim3, dim2) if transpose[0] else (dim2, dim3)
    dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
    req_grad = list(req_grad)
    req_grad[2] = False

    for _ in range(3):
        # normal multiply
        if torch_fn in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)

            torch.nn.init.xavier_uniform_(B)

            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

            if not transpose[0] and transpose[1]:
                out_torch = torch_fn(A, B.t())
                out_bnb = bnb_fn(A, B.t(), fw_code, bw_code)
            elif not transpose[0] and not transpose[1]:
                out_torch = torch_fn(A, B)
                out_bnb = bnb_fn(A, B, fw_code, bw_code)

            assert out_bnb.dtype == A.dtype, f"bnb matmul_fp8 received {A.dtype} but returned {out_bnb.dtype}"

            # The mean of an empty tensor is NaN, so only check non-empty outputs.
            if out_bnb.numel() > 0:
                err = torch.abs(out_bnb - out_torch).float().mean().item()
                assert err < 0.115

            if any(req_grad):
                # Overwrite the bnb output values with the reference values so
                # both losses are evaluated on identical numbers.
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)

                if req_grad[1]:
                    n = gradB1.numel()
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        # Empty A: the gradients w.r.t. B must be exactly zero.
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0

                    close = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (~close).sum().item() <= n * 0.1
                    close = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (~close).sum().item() <= n * 0.02
                    grad_err = (gradB1 - gradB2).abs().mean()
                    assert grad_err.item() < 0.003
                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)