test_linear4bit.py 12.8 KB
Newer Older
1
import copy
Aarni Koskela's avatar
Aarni Koskela committed
2
import os
3
import pickle
4
import platform
Ruslan Svirschevski's avatar
Ruslan Svirschevski committed
5
6
7
8
9
10
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
11
from bitsandbytes.cextension import HIP_ENVIRONMENT
12
13
14
15
16
from tests.helpers import (
    TRUE_FALSE,
    describe_dtype,
    get_available_devices,
    id_formatter,
17
    is_supported_on_hpu,
18
19
20
    torch_load_from_buffer,
    torch_save_to_buffer,
)
Ruslan Svirschevski's avatar
Ruslan Svirschevski committed
21

22
# Maps the string ids used in the ``quant_storage`` parametrization to the torch dtypes that are
# passed to Linear4bit / Params4bit as ``quant_storage``.
storage = dict(
    uint8=torch.uint8,
    float16=torch.float16,
    bfloat16=torch.bfloat16,
    float32=torch.float32,
)
Ruslan Svirschevski's avatar
Ruslan Svirschevski committed
28

29

30
def _assert_same_tensor(expected, actual):
    """Assert that two tensors share a device and dtype and are bitwise equal."""
    assert expected.device == actual.device
    assert expected.dtype == actual.dtype
    assert torch.equal(expected, actual)


def _assert_same_quant_attrs(state_a, state_b):
    """Compare the ``code``, ``dtype``, ``blocksize`` and ``absmax`` attributes of two quant states.

    Tensor-valued attributes must be bitwise equal (``torch.equal``); other values must compare ``==``.
    """
    for attr in ("code", "dtype", "blocksize", "absmax"):
        c, d = getattr(state_a, attr), getattr(state_b, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("original_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(
    device, quant_type, original_dtype, compress_statistics, bias, quant_storage, save_before_forward
):
    """Round-trip a Linear4bit layer through ``state_dict`` and ``torch.save`` and check nothing changes.

    Three quantized layers are built from one ``torch.nn.Linear``:

    * ``linear_q``  -- quantized directly from the original weights;
    * ``linear_q2`` -- rebuilt from ``linear_q.state_dict()`` via ``Params4bit.from_prequantized``;
    * ``linear_qs`` -- quantized directly, with the parametrized ``quant_storage`` dtype.

    Their weights, quant states, biases and forward outputs must match exactly. The test also checks:

    * ``linear_q2`` still produces the same output after a CPU round trip;
    * ``linear_q`` reloads correctly from a ``torch.save`` buffer, taken before or after the first forward;
    * the saved 4-bit state dict is small enough on disk relative to the original.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type, original_dtype, storage[quant_storage]):
        pytest.skip("This configuration is not supported on HPU.")

    compute_dtype = None
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

    # Quantizing original layer.
    # NOTE(review): compress_statistics is handed to Linear4bit, but the weight is then replaced by a
    # Params4bit built without it, so the flag may not reach linear_q's quant_state -- confirm intent.
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
    linear_q = linear_q.to(device)

    # Saving to state_dict.
    sd = linear_q.state_dict()

    # Restoring from state_dict: after popping weight/bias, the remaining entries are passed as
    # quantized_stats to rebuild the quantized weight.
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)

    # Creating a new layer with the same params and loading the restored weights into it.
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    linear_q2.weight = weight2
    if bias:
        linear_q2.bias = torch.nn.Parameter(bias_data2)
    linear_q2 = linear_q2.to(device)

    # Quantizing the original layer again, this time with the requested quant_storage dtype.
    linear_qs = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

    # MATCHING: the reloaded layer must hold exactly the same weights, quant state and bias.
    a, b = linear_q.weight, linear_q2.weight
    _assert_same_tensor(a, b)

    q0 = a.quant_state
    q1 = b.quant_state
    _assert_same_quant_attrs(q0, q1)
    # Both quant states must agree on whether a secondary state (state2) exists before comparing it.
    assert (q0.state2 is None) == (q1.state2 is None)
    if q0.state2 is not None:
        _assert_same_quant_attrs(q0.state2, q1.state2)

    if bias:
        _assert_same_tensor(linear_q.bias, linear_q2.bias)

    if save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)

    # Forward test: all three quantized layers must produce identical outputs.
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    c = linear_qs(x)
    _assert_same_tensor(a, b)
    _assert_same_tensor(a, c)

    if not save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)
    linear_q3 = torch_load_from_buffer(bytes_4bit)

    # Test moving to CPU and back: call linear_q2 itself (the layer that was moved) so the
    # round trip is actually exercised; its output must equal its pre-move output.
    if device != "cpu":
        linear_q2.to("cpu")
        linear_q2.to(device)
    _assert_same_tensor(b, linear_q2(x))

    # The layer reloaded from the torch.save buffer must reproduce linear_q's output.
    _assert_same_tensor(a, linear_q3(x))

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
        ratio_error_msg = (
            f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        )
        assert size_ratio < target_compression, ratio_error_msg
191
192


193
194
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
    """A shallow ``copy.copy`` of a Params4bit moved to ``device`` shares its quant_state and storage."""
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    source = bnb.nn.Params4bit(
        data=torch.randn(300, 400),
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)

    duplicate = copy.copy(source)

    # Shallow copy: the very same QuantState object and the same underlying buffer.
    assert source.quant_state is duplicate.quant_state
    assert source.data.data_ptr() == duplicate.data.data_ptr()


215
216
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
    """``copy.deepcopy`` of a Params4bit must yield an independent but equal copy.

    The copy must own its own quant_state object and storage and hold the same values.
    The original's instance ``__dict__`` keys must be left untouched by the copy.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    tensor = torch.randn(300, 400)
    param = bnb.nn.Params4bit(
        data=tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)
    dict_keys_before = set(param.__dict__.keys())
    copy_param = copy.deepcopy(param)
    dict_keys_after = set(param.__dict__.keys())
    dict_keys_copy = set(copy_param.__dict__.keys())

    # Independence: distinct quant_state object and distinct storage.
    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()

    # Independence alone is not enough -- the copy must also carry the same values.
    assert torch.equal(param.data, copy_param.data)
    assert param.quant_state == copy_param.quant_state

    # there was a bug where deepcopy would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_copy

243

244
245
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
    """Pickling a Params4bit after moving it to ``device`` must round-trip its data and metadata.

    Also guards against pickling changing the set of keys in the original object's ``__dict__``.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    original_tensor = torch.randn(300, 400)
    original_param = bnb.nn.Params4bit(
        data=original_tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
    )
    dict_keys_before = set(original_param.__dict__.keys())

    original_param.to(device)  # change device to trigger quantization

    # pickle.loads is acceptable here only because the payload was produced by this test itself.
    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)

    dict_keys_after = set(original_param.__dict__.keys())
    dict_keys_deserialized = set(deserialized_param.__dict__.keys())

    assert torch.equal(original_param.data, deserialized_param.data)
    # requires_grad was never set explicitly; both objects are expected to report False.
    assert original_param.requires_grad is False
    assert deserialized_param.requires_grad is False
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state

    # Pickling must not modify the original object (a similar bug previously affected deepcopy).
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_deserialized
278
279
280
281
282
283
284
285
286
287
288


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
    """``torch.compile`` of a stack of Linear4bit layers must match eager execution.

    Compares forward outputs (under ``no_grad``) and input gradients between the eager network and
    its compiled counterpart. The compiled model uses ``hpu_backend`` on HPU and ``inductor`` elsewhere.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
        pytest.skip("fullgraph mode requires torch 2.8 or higher")

    if device == "cuda" and platform.system() == "Windows":
        pytest.skip("Triton is not officially supported on Windows")

    # Known regression on Linux aarch64 CPU with torch 2.6.x when fullgraph=False.
    linux_aarch64_cpu = device == "cpu" and platform.system() == "Linux" and platform.machine() == "aarch64"
    torch_2_6_x = torch.__version__ >= (2, 6) and torch.__version__ < (2, 7)
    if linux_aarch64_cpu and torch_2_6_x and not fullgraph:
        pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")

    dim, batch_size = 256, 16

    torch.compiler.reset()

    def make_layer():
        # One square dim -> dim quantized layer configured from the test parameters.
        return bnb.nn.Linear4bit(
            dim,
            dim,
            bias=bias,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )

    # A small network of four stacked Linear4bit layers.
    net = torch.nn.Sequential(*(make_layer() for _ in range(4))).to(device)

    x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)

    # Eager reference, taken before compiling.
    with torch.no_grad():
        ref_output = net(x)

    backend = "hpu_backend" if device == "hpu" else "inductor"
    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=backend)

    with torch.no_grad():
        compiled_output = compiled_net(x)

    for prop in ("shape", "device", "dtype"):
        assert getattr(compiled_output, prop) == getattr(ref_output, prop)
    torch.testing.assert_close(compiled_output, ref_output)

    # Gradient check: backprop sum(model(x)) and compare the gradients w.r.t. the input.
    x.requires_grad_(True)

    def input_grad(model):
        x.grad = None
        model(x).sum().backward()
        return x.grad.clone()

    grad_ref = input_grad(net)
    grad_compiled = input_grad(compiled_net)

    torch.testing.assert_close(grad_compiled, grad_ref)