test_autograd.py 10.6 KB
Newer Older
1
import pytest
Tim Dettmers's avatar
Tim Dettmers committed
2
3
import torch

4
import bitsandbytes as bnb
Aarni Koskela's avatar
Aarni Koskela committed
5
6
7
8
from tests.helpers import (
    BOOLEAN_TRIPLES,
    TRUE_FALSE,
    describe_dtype,
9
    get_available_devices,
Aarni Koskela's avatar
Aarni Koskela committed
10
    id_formatter,
11
    is_supported_on_hpu,
12
)
Aarni Koskela's avatar
Aarni Koskela committed
13
14
15
16

# Each entry is a (transpose_A, transpose_B) pair: index 0 selects the shape of
# A and index 1 the shape of B in the tests in this file. A is never transposed
# in either entry.
TRANSPOSE_VALS = [(False, True), (False, False)]


17
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [48], ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize(
    "funcs",
    [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)],
    ids=["func=matmul", "func=switchback_bnb"],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(
    device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
):
    """Compare the bnb matmul in ``funcs[1]`` against the reference in ``funcs[0]``.

    The forward outputs are compared with element-wise tolerances; when
    ``has_fp16_weights`` is set and any input requires grad, the gradients
    produced through both outputs are compared as well.

    Parameters (all supplied by ``pytest.mark.parametrize``):
        device: Device string the tensors are created on.
        dim1: Not referenced in the body; kept so the parametrization is unchanged.
        dim2, dim3, dim4: A is ``(dim2, dim3)``, B is ``(dim3, dim4)`` (or
            ``(dim4, dim3)`` when ``transpose[1]``), the output is ``(dim2, dim4)``.
        funcs: ``(reference_fn, bnb_fn)`` pair.
        dtype: dtype of A, B, bias and target.
        req_grad: ``requires_grad`` flags for ``(A, B, bias)``.
        transpose: ``(transpose_A, transpose_B)`` pair.
        decomp: Value assigned to ``MatmulLtState.threshold``; for ``6.0`` a
            random set of columns of A is filled with ``6.0``.
        has_fp16_weights: Assigned to ``MatmulLtState.has_fp16_weights``; when
            false, B is quantized up front with ``int8_vectorwise_quant``.
        has_bias: Whether a bias vector is passed to the bnb function and added
            to the reference output.
    """
    if device != "cuda":
        if funcs[1] == bnb.research.switchback_bnb:
            # TODO: Deprecate/remove?
            pytest.skip("switchback_bnb only works on CUDA.")

        if req_grad[1]:
            # This will be deprecated for CUDA in the future. We don't expect
            # this to work on any other device.
            pytest.skip("Deprecated feature with CUDA support only.")

    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    # Column indices of A that get overwritten with 6.0 when decomp == 6.0.
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)

    if not has_bias:
        # Without a bias tensor there is nothing whose gradient could be checked.
        req_grad = list(req_grad)
        req_grad[2] = False

    if (
        device == "cpu"
        and dtype != torch.float32
        and has_fp16_weights
        and any(req_grad)
        and torch.__version__ < (2, 6)
    ):
        pytest.xfail("mse_loss bf16/fp16 on CPU is not supported in torch < 2.6")

    # Repeat with fresh random inputs.
    for _ in range(3):
        # normal multiply
        if funcs[0] not in (torch.mm, torch.matmul):
            continue

        A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
        if decomp == 6.0:
            with torch.no_grad():
                A[:, outlier_dim] = 6.0
        B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
        target = torch.randn(
            size=(dim2, dim4),
            device=device,
            requires_grad=req_grad[1],
            dtype=dtype,
        )

        bias = None
        bias2 = None
        if has_bias:
            bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
            bias2 = bias.clone()

        torch.nn.init.xavier_uniform_(B)
        B2 = B.clone()

        state = bnb.MatmulLtState()
        state.threshold = decomp
        state.has_fp16_weights = has_fp16_weights
        if not has_fp16_weights:
            if not transpose[0] and not transpose[1]:
                B2 = B2.t().contiguous()

            # Quantize the weight ahead of time; the int8 tensor replaces B2
            # as the weight argument of the bnb call below.
            state.CB, state.SCB, _ = bnb.functional.int8_vectorwise_quant(B2.to(torch.float16))
            B2 = state.CB

        if not transpose[0] and transpose[1]:
            out_torch = funcs[0](A, B.t())
            out_bnb = funcs[1](A, B2, state=state, bias=bias2)
        elif not transpose[0] and not transpose[1]:
            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)

        if has_bias:
            out_torch += bias

        assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

        n = out_bnb.numel()

        # Only a small fraction of elements may fall outside the tight tolerance ...
        idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
        assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
        # ... and at most 0.1% may fall outside the looser one.
        idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
        assert (idx == 0).sum().item() <= n * 0.001

        if not has_fp16_weights:
            # Gradients are only checked for the fp16-weights path.
            continue

        if any(req_grad):
            # Overwrite the bnb output values with the reference values before
            # computing the loss -- presumably so that both backward passes
            # start from identical forward values.
            out_bnb.data.copy_(out_torch)
            if device == "cuda":
                torch.cuda.synchronize()

            loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
            loss_bnb.backward()
            gradA1 = A.grad
            gradB1 = B.grad
            # Reset so the reference backward pass does not accumulate on top.
            A.grad = None
            B.grad = None
            if has_bias:
                gradBias1 = bias.grad
                bias.grad = None

            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
            loss_torch.backward()
            gradA2 = A.grad
            gradB2 = B.grad
            A.grad = None
            B.grad = None
            if has_bias:
                gradBias2 = bias.grad
                bias.grad = None

        if req_grad[0]:
            torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)

        if req_grad[1]:
            n = gradB1.numel()
            if dim2 > 0:
                assert torch.abs(gradB1).sum() > 0.0
                assert torch.abs(gradB2).sum() > 0.0
            else:
                # With dim2 == 0, A has no rows, so B's gradient is expected to be all zeros.
                assert torch.abs(gradB1).sum() == 0.0
                assert torch.abs(gradB2).sum() == 0.0

            idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
            assert (idx == 0).sum().item() <= n * 0.10

            idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
            assert (idx == 0).sum().item() <= n * 0.02

            torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

        if req_grad[2]:
            torch.testing.assert_close(gradBias1, gradBias2)
Tim Dettmers's avatar
Tim Dettmers committed
156
157


158
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [96], ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"])
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
def test_matmul_4bit(
    device,
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    has_bias,
    compress_statistics,
    quant_type,
):
    """Compare ``funcs[1]`` on a 4-bit-quantized B against ``funcs[0]`` on the original B.

    The mean absolute forward error is bounded, and when any input requires
    grad the gradients for A and the bias obtained through both outputs are
    compared. Gradients for B are not compared by this test.

    Parameters (all supplied by ``pytest.mark.parametrize``):
        device: Device string the tensors are created on.
        dim1: Not referenced in the body; kept so the parametrization is unchanged.
        dim2, dim3, dim4: A is ``(dim2, dim3)``, B is ``(dim3, dim4)`` (or
            ``(dim4, dim3)`` when ``transpose[1]``), the output is ``(dim2, dim4)``.
        funcs: ``(reference_fn, bnb_fn)`` pair.
        dtype: dtype of A, B, bias and target.
        req_grad: ``requires_grad`` flags for ``(A, B, bias)``.
        transpose: ``(transpose_A, transpose_B)`` pair.
        has_bias: Whether a bias vector is passed to the bnb function and added
            to the reference output.
        compress_statistics, quant_type: Forwarded to
            ``bnb.functional.quantize_4bit``.
    """
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    if not has_bias:
        # Without a bias tensor there is nothing whose gradient could be checked.
        req_grad = list(req_grad)
        req_grad[2] = False

    if device == "cpu" and dtype != torch.float32 and any(req_grad) and torch.__version__ < (2, 6):
        pytest.xfail("mse_loss fp16 on CPU is not supported in torch < 2.6")

    if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
        pytest.skip("This configuration is not supported on HPU.")

    # Repeat with fresh random inputs.
    for _ in range(3):
        # normal multiply
        if funcs[0] not in (torch.mm, torch.matmul):
            continue

        A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
        B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
        target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)

        bias = None
        bias2 = None
        if has_bias:
            bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
            bias2 = bias.clone()
        torch.nn.init.xavier_uniform_(B)

        B2, quant_state = bnb.functional.quantize_4bit(
            B,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )

        if not transpose[0] and transpose[1]:
            out_torch = funcs[0](A, B.t())
            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
        elif not transpose[0] and not transpose[1]:
            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)

        if has_bias:
            out_torch += bias

        assert out_bnb.dtype == A.dtype, f"bnb matmul_4bit received {A.dtype} but returned {out_bnb.dtype}"

        # The mean over an empty output is NaN, so only check non-empty outputs.
        if out_bnb.numel() > 0:
            err = torch.abs(out_bnb - out_torch).float().mean().item()
            assert err < 0.115

        if not any(req_grad):
            continue

        # Overwrite the bnb output values with the reference values before
        # computing the loss -- presumably so that both backward passes start
        # from identical forward values.
        out_bnb.data.copy_(out_torch)
        if device == "cuda":
            torch.cuda.synchronize()
        elif device == "hpu":
            torch.hpu.synchronize()

        loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
        loss_bnb.backward()
        gradA1 = A.grad
        # Reset so the reference backward pass does not accumulate on top.
        A.grad = None
        B.grad = None
        if has_bias:
            gradBias1 = bias.grad
            bias.grad = None

        loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
        loss_torch.backward()
        gradA2 = A.grad
        A.grad = None
        B.grad = None
        if has_bias:
            gradBias2 = bias.grad
            bias.grad = None

        if req_grad[0]:
            torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)

        if req_grad[2]:
            torch.testing.assert_close(gradBias1, gradBias2)