# Tests for bitsandbytes ``Linear8bitLt`` (test_linear8bitlt.py).
from contextlib import nullcontext
import copy
import os
import pickle
import platform
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
    TRUE_FALSE,
    get_available_devices,
    id_formatter,
    torch_load_from_buffer,
    torch_save_to_buffer,
)

# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.parametrize("device", get_available_devices())
def test_linear_no_igemmlt(device):
    """Compare ``Linear8bitLt`` (threshold=6.0, no fp16 weights) with an fp16 ``nn.Linear``.

    Both layers start from the same weights and bias. The forward outputs and the
    input gradients of a random-projection loss must agree within loose tolerances.
    """
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )

    # TODO: Remove, this is no longer implemented
    linear_custom.state.force_no_igemmlt = True

    # The 8-bit layer gets a copy of the reference weights and the reference bias Parameter itself.
    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(),
        requires_grad=False,
        has_fp16_weights=False,
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.to(device)
    linear = linear.half().to(device)

    x_ref = x.clone().to(device).requires_grad_(True)
    x_ours = x.clone().to(device).requires_grad_(True)

    # Backpropagate the same random projection through both layers.
    fx_ref = linear(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()

    assert linear_custom.state.CB is not None
    assert not linear_custom.state.has_fp16_weights

    # At most 0.025% of outputs may miss the tight tolerance; all must meet the looser one.
    close = torch.isclose(fx_ref, fx_ours, atol=0.02, rtol=1e-5)
    assert (~close).sum().item() < fx_ref.numel() * 2.5e-4
    torch.testing.assert_close(fx_ref, fx_ours, atol=0.03, rtol=1e-5)
    torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
def test_linear_serialization(
    device,
    has_fp16_weights,
    threshold,
    serialize_before_forward,
    deserialize_before_cuda,
    save_before_forward,
    load_before_cuda,
):
    """Round-trip a ``Linear8bitLt`` through ``state_dict`` and a whole-module buffer.

    The layer is captured (``state_dict()`` and ``torch_save_to_buffer``) either before
    or after a forward/backward pass. It is then restored into a fresh layer via
    ``load_state_dict`` and as a module via ``torch_load_from_buffer``. Each restore
    happens before or after the fresh layer is moved to ``device``. The restored
    layers' outputs and input gradients must match the original's.
    """
    if device != "cuda" and has_fp16_weights:
        pytest.skip("has_fp16_weights is only supported on CUDA and is deprecated")

    linear = torch.nn.Linear(32, 96)
    # TODO: Fallback for bad shapes; e.g. torch.randn(3, 32, dtype=torch.half) is not exercised yet.
    x = torch.randn(4, 32, dtype=torch.half)

    def make_layer():
        # Fresh Linear8bitLt with the same configuration as the reference layer.
        return Linear8bitLt(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            has_fp16_weights=has_fp16_weights,
            threshold=threshold,
        )

    linear_custom = make_layer()
    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(),
        requires_grad=has_fp16_weights,
        has_fp16_weights=has_fp16_weights,
    )
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.to(device)

    if serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    if save_before_forward:
        bytes_8bit = torch_save_to_buffer(linear_custom)

    x_first = x.clone().to(device).requires_grad_(True)
    fx_first = linear_custom(x_first).float()
    grad_proj = torch.randn_like(fx_first)
    (fx_first * grad_proj).mean().backward()

    if not serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    if not save_before_forward:
        bytes_8bit = torch_save_to_buffer(linear_custom)

    with TemporaryDirectory() as tmpdir:
        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")

        torch.save(linear.state_dict(), state_path)
        torch.save(state_dict_8bit, state_path_8bit)

        # Without fp16 weights the 8-bit checkpoint must be under half the unquantized size.
        if not has_fp16_weights:
            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)

        new_state_dict = torch.load(state_path_8bit, weights_only=False)

    new_linear_custom = make_layer()

    if deserialize_before_cuda:
        # Loading the state dict before the device move must raise unless fp16 weights are kept.
        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
            new_linear_custom.load_state_dict(new_state_dict, strict=True)

    if load_before_cuda:
        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

    new_linear_custom = new_linear_custom.to(device)

    if not deserialize_before_cuda:
        new_linear_custom.load_state_dict(new_state_dict, strict=True)

    if not load_before_cuda:
        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

    x_second = x.clone().to(device).requires_grad_(True)
    fx_second = new_linear_custom(x_second).float()
    (fx_second * grad_proj).mean().backward()

    x_third = x.clone().to(device).requires_grad_(True)
    fx_third = new_linear_custom2(x_third).float()
    (fx_third * grad_proj).mean().backward()

    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
    if has_fp16_weights or not deserialize_before_cuda:
        assert torch.allclose(fx_first, fx_second, atol=1e-5)
        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
    assert torch.allclose(fx_first, fx_third, atol=1e-5)
    assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)


@pytest.fixture
def linear8bit(requires_cuda):
    """Return a CUDA ``Linear8bitLt`` (no fp16 weights, threshold=6.0) built from a random 32->96 ``nn.Linear``."""
    linear = torch.nn.Linear(32, 96)
    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(),
        requires_grad=False,
        has_fp16_weights=False,
    )
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()
    return linear_custom


def test_linear8bit_copy_param(linear8bit):
    """A shallow copy must reuse the very same weight and bias Parameter objects."""
    clone = copy.copy(linear8bit)
    for attr_name in ("weight", "bias"):
        assert getattr(linear8bit, attr_name) is getattr(clone, attr_name)
    assert clone.weight.data.data_ptr() == linear8bit.weight.data.data_ptr()


def test_linear8bit_deepcopy_param(linear8bit):
    """A deep copy must own new, equal-valued parameters, state, and SCB/CB buffers."""
    dup = copy.deepcopy(linear8bit)
    src_weight, dup_weight = linear8bit.weight, dup.weight

    assert src_weight is not dup_weight
    assert linear8bit.bias is not dup.bias
    assert src_weight.data.data_ptr() != dup_weight.data.data_ptr()
    assert torch.allclose(src_weight.data, dup_weight.data)
    assert linear8bit.state == dup.state

    # check for a bug where SCB and CB were not copied
    for buffer_name in ("SCB", "CB"):
        copied = getattr(dup_weight, buffer_name)
        assert copied is not None
        assert (getattr(src_weight, buffer_name) == copied).all()


def test_linear8bit_serialization(linear8bit):
    """A pickle round-trip must yield new, equal parameters and preserve SCB/CB.

    ``pickle.loads`` is only applied to bytes this test produced itself.
    """
    serialized = pickle.dumps(linear8bit)
    deserialized = pickle.loads(serialized)
    assert linear8bit.weight.data.data_ptr() != deserialized.weight.data.data_ptr()
    assert torch.allclose(linear8bit.weight.data, deserialized.weight.data)
    assert linear8bit.bias.data.data_ptr() != deserialized.bias.data.data_ptr()
    assert torch.allclose(linear8bit.bias.data, deserialized.bias.data)
    assert linear8bit.state == deserialized.state

    # check for a bug where SCB and CB were not copied.
    # Explicit None checks (as in the deepcopy test) make a dropped buffer fail as a
    # clear assertion instead of an AttributeError from calling ``.all()`` on a bool.
    assert deserialized.weight.SCB is not None
    assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
    assert deserialized.weight.CB is not None
    assert (linear8bit.weight.CB == deserialized.weight.CB).all()


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
    """Check that ``torch.compile`` of stacked ``Linear8bitLt`` layers matches eager mode.

    Forward outputs are always compared. Input gradients are compared only when
    ``threshold == 0`` and the platform is not the known-broken torch 2.6 aarch64 CPU setup.
    """
    if device == "cuda" and platform.system() == "Windows":
        pytest.skip("Triton is not officially supported on Windows")

    dim = 256
    batch_size = 16

    torch.compiler.reset()

    # Create a small network with Linear8bitLt layers
    net = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
    ).to(device)

    # Only enable capture of data-dependent output-shape ops for fullgraph with a nonzero
    # threshold (presumably needed by the outlier path — confirm).
    dynamic_output_shapes = fullgraph and threshold > 0
    with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes):
        x = torch.randn(batch_size, dim, dtype=torch.float16, device=device)

        # Eager-mode reference, computed before compilation.
        with torch.no_grad():
            ref_output = net(x)

        compile_backend = "hpu_backend" if device == "hpu" else "inductor"
        compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=compile_backend)

        with torch.no_grad():
            compiled_output = compiled_net(x)

        assert compiled_output.shape == ref_output.shape
        assert compiled_output.device == ref_output.device
        assert compiled_output.dtype == ref_output.dtype
        torch.testing.assert_close(compiled_output, ref_output)

        # Test with gradients. Currently only works with threshold=0.
        # A regression was observed with torch 2.6.x on Linux aarch64 CPU, so skip it there.
        is_broken_platform = (
            device == "cpu"
            and platform.system() == "Linux"
            and platform.machine() == "aarch64"
            and (2, 6) <= torch.__version__ < (2, 7)
        )

        if threshold == 0 and not is_broken_platform:
            x.requires_grad_(True)
            y1 = net(x).sum()
            y1.backward()
            grad_ref = x.grad.clone()

            # Reset the accumulated gradient before the compiled pass.
            x.grad = None
            y2 = compiled_net(x).sum()
            y2.backward()
            grad_compiled = x.grad.clone()

            torch.testing.assert_close(grad_compiled, grad_ref)