test_linear4bit.py
import copy
import os
import pickle
import platform
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import (
    TRUE_FALSE,
    describe_dtype,
    get_available_devices,
    id_formatter,
    is_supported_on_hpu,
    torch_load_from_buffer,
    torch_save_to_buffer,
)

storage = {
    "uint8": torch.uint8,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("original_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(
    device, quant_type, original_dtype, compress_statistics, bias, quant_storage, save_before_forward
):
    if device == "hpu" and not is_supported_on_hpu(quant_type, original_dtype, storage[quant_storage]):
        pytest.skip("This configuration is not supported on HPU.")

    compute_dtype = None
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
    linear_q = linear_q.to(device)

    # saving to state_dict:
    sd = linear_q.state_dict()

    # restoring from state_dict:
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    # loading weights from state_dict:
    linear_q2.weight = weight2
    if bias:
        linear_q2.bias = torch.nn.Parameter(bias_data2)
    linear_q2 = linear_q2.to(device)

    # MATCHING: the restored weights must be identical to the originally quantized weights
    a, b = linear_q.weight, linear_q2.weight

    # Quantizing original layer with specified quant_storage type
    linear_qs = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    q0 = a.quant_state
    q1 = b.quant_state
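    # every field of the quant_state (including the nested state2 checked below) must survive the state_dict round trip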
    for attr in ("code", "dtype", "blocksize", "absmax"):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"

    if q0.state2 is not None:
        for attr in ("code", "dtype", "blocksize", "absmax"):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)
            else:
                assert c == d, f"{c} != {d}"

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

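    # optionally serialize the whole quantized module to an in-memory buffer before it has ever run a forward pass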
    if save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)

    # Forward test
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    c = linear_qs(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert a.device == c.device
    assert a.dtype == c.dtype
    assert torch.equal(a, b)
    assert torch.equal(a, c)

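    # restore a third copy of the layer from the serialized buffer (saved either before or after the forward pass)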
    if not save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)
    linear_q3 = torch_load_from_buffer(bytes_4bit)

    # Test moving to CPU and back to the target device
    if device != "cpu":
        linear_q2.to("cpu")
        linear_q2.to(device)
    d = linear_qs(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

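    # the layer restored from the serialized buffer must produce identical outputs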
    d = linear_q3(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
        ratio_error_msg = (
            f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        )
        assert size_ratio < target_compression, ratio_error_msg


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    tensor = torch.randn(300, 400)
    param = bnb.nn.Params4bit(
        data=tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)

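    # a shallow copy must share the quant_state object and the underlying tensor storage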
    shallow_copy_param = copy.copy(param)
    assert param.quant_state is shallow_copy_param.quant_state
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    tensor = torch.randn(300, 400)
    param = bnb.nn.Params4bit(
        data=tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)
    dict_keys_before = set(param.__dict__.keys())
    copy_param = copy.deepcopy(param)
    dict_keys_after = set(param.__dict__.keys())
    dict_keys_copy = set(copy_param.__dict__.keys())

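    # a deep copy must own its own quant_state and its own tensor storage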
    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()

    # there was a bug where deepcopy would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_copy


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    original_tensor = torch.randn(300, 400)
    original_param = bnb.nn.Params4bit(
        data=original_tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
    )
    dict_keys_before = set(original_param.__dict__.keys())

    original_param.to(device)  # change device to trigger quantization

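    # pickle round trip of the quantized parameter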
    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)
    dict_keys_after = set(original_param.__dict__.keys())
    dict_keys_deserialized = set(deserialized_param.__dict__.keys())

    assert torch.equal(original_param.data, deserialized_param.data)
    assert original_param.requires_grad == deserialized_param.requires_grad == False
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state

    # there was a bug where serialization would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_deserialized


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
        pytest.skip("fullgraph mode requires torch 2.8 or higher")

    if device == "cuda" and platform.system() == "Windows":
        pytest.skip("Triton is not officially supported on Windows")

    # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
    if (
        not fullgraph
        and device == "cpu"
        and platform.machine() == "aarch64"
        and platform.system() == "Linux"
        and ((2, 7) > torch.__version__ >= (2, 6))
    ):
        pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")

    dim = 256
    batch_size = 16

    torch.compiler.reset()

    # Create a small network with Linear4bit layers
    net = torch.nn.Sequential(
        *[
            bnb.nn.Linear4bit(
                dim,
                dim,
                bias=bias,
                compute_dtype=compute_dtype,
                compress_statistics=compress_statistics,
                quant_type=quant_type,
            )
            for _ in range(4)
        ]
    ).to(device)

    # Create input tensor
    x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)

    # Get reference output before compilation
    with torch.no_grad():
        ref_output = net(x)

    # Compile the model
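    # HPU requires its own compile backend; other devices use inductor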
    compile_backend = "hpu_backend" if device == "hpu" else "inductor"
    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=compile_backend)

    # Get output from compiled model
    with torch.no_grad():
        compiled_output = compiled_net(x)

    # Check outputs match
    assert compiled_output.shape == ref_output.shape
    assert compiled_output.device == ref_output.device
    assert compiled_output.dtype == ref_output.dtype
    torch.testing.assert_close(compiled_output, ref_output)

    # Test with gradients
    x.requires_grad_(True)
    y1 = net(x).sum()
    y1.backward()
    grad_ref = x.grad.clone()

    x.grad = None
    y2 = compiled_net(x).sum()
    y2.backward()
    grad_compiled = x.grad.clone()

    torch.testing.assert_close(grad_compiled, grad_ref)