test_linear4bit.py 14.4 KB
Newer Older
1
import copy
Aarni Koskela's avatar
Aarni Koskela committed
2
import os
3
import pickle
4
import platform
Ruslan Svirschevski's avatar
Ruslan Svirschevski committed
5
6
7
8
9
10
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
11
from bitsandbytes.cextension import HIP_ENVIRONMENT
12
13
14
15
16
from tests.helpers import (
    TRUE_FALSE,
    describe_dtype,
    get_available_devices,
    id_formatter,
17
    is_supported_on_hpu,
18
19
20
    torch_load_from_buffer,
    torch_save_to_buffer,
)
Ruslan Svirschevski's avatar
Ruslan Svirschevski committed
21

22
# Lookup from the ``quant_storage`` names used in test parametrization to torch dtypes.
storage = {name: getattr(torch, name) for name in ("uint8", "float16", "bfloat16", "float32")}
Ruslan Svirschevski's avatar
Ruslan Svirschevski committed
28

29

30
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("original_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(
    device, quant_type, original_dtype, compress_statistics, bias, quant_storage, save_before_forward
):
    """Round-trip a quantized ``Linear4bit`` through ``state_dict`` and ``torch.save``.

    Three quantized copies of one ``torch.nn.Linear`` are built:

    * ``linear_q``  - quantized directly from the original weights,
    * ``linear_q2`` - rebuilt from ``linear_q.state_dict()`` via
      ``Params4bit.from_prequantized``,
    * ``linear_qs`` - quantized with the parametrized ``quant_storage`` dtype.

    The test checks that their weights, quantization states, biases and forward
    outputs match exactly. It then checks three more things:

    * ``linear_q2`` still matches after a CPU round-trip.
    * A save/load of ``linear_q`` (before or after a forward pass) reproduces
      the same output.
    * The serialized 4-bit state dict is smaller than a target fraction of the
      original layer's.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type, original_dtype, storage[quant_storage]):
        pytest.skip("This configuration is not supported on HPU.")

    compute_dtype = None
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    # NOTE(review): compress_statistics is not forwarded to this Params4bit, nor to the
    # one assigned to linear_qs below. The weights therefore use Params4bit's own
    # default rather than the parametrized value. Confirm whether that is intended;
    # the size-ratio target at the end of this test may depend on it.
    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
    linear_q = linear_q.to(device)

    # saving to state_dict:
    sd = linear_q.state_dict()

    # restoring from state_dict:
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    # loading weights from state_dict:
    linear_q2.weight = weight2
    if bias:
        linear_q2.bias = torch.nn.Parameter(bias_data2)
    linear_q2 = linear_q2.to(device)

    # MATCHING
    a, b = linear_q.weight, linear_q2.weight

    # Quantizing original layer with specified quant_storage type
    linear_qs = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    def _assert_same_state_fields(state_a, state_b):
        # Tensors need torch.equal; plain values (dtype, blocksize) compare with ==.
        for attr in ("code", "dtype", "blocksize", "absmax"):
            val_a, val_b = getattr(state_a, attr), getattr(state_b, attr)
            if isinstance(val_a, torch.Tensor):
                assert torch.equal(val_a, val_b)
            else:
                assert val_a == val_b, f"{val_a} != {val_b}"

    q0 = a.quant_state
    q1 = b.quant_state
    _assert_same_state_fields(q0, q1)
    # The nested state only exists when the statistics were themselves quantized.
    if q0.state2 is not None:
        _assert_same_state_fields(q0.state2, q1.state2)

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

    if save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)

    # Forward test
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    c = linear_qs(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert a.device == c.device
    assert a.dtype == c.dtype
    assert torch.equal(a, b)
    assert torch.equal(a, c)

    if not save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)
    linear_q3 = torch_load_from_buffer(bytes_4bit)

    # Test moving to CPU and back to GPU: the layer that made the round-trip must
    # reproduce the output it gave before the move.
    if device != "cpu":
        linear_q2.to("cpu")
        linear_q2.to(device)
    d = linear_q2(x)
    assert b.dtype == d.dtype
    assert b.device == d.device
    assert torch.equal(b, d)

    # The layer reloaded from the torch.save buffer must match as well.
    d = linear_q3(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        # original_dtype is currently only parametrized with 16-bit dtypes, so the
        # float32 target is kept for completeness only.
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
        ratio_error_msg = (
            f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        )
        assert size_ratio < target_compression, ratio_error_msg
191
192


193
194
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
    """A shallow ``copy.copy`` of a device-moved ``Params4bit`` shares state with its source.

    The copy must reference the very same ``quant_state`` object and the same
    underlying data pointer as the parameter it was copied from.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    source = bnb.nn.Params4bit(
        data=torch.randn(300, 400),
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)

    shallow = copy.copy(source)

    # Nothing may be duplicated: same state object, same storage.
    assert source.quant_state is shallow.quant_state
    assert source.data.data_ptr() == shallow.data.data_ptr()


ved1beta's avatar
ved1beta committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_params4bit_torch_chunk_split(device, quant_type):
    """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility.

    Both ops must return a non-empty tuple. Every element must still be a
    ``Params4bit`` carrying the same ``quant_type`` as the source parameter.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
        pytest.skip("This configuration is not supported on HPU.")

    if device == "cpu":
        pytest.skip("CPU quantization causes segfault, skipping CPU test")

    original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")

    params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)

    # The CPU case was skipped above, so the move always happens; moving the
    # parameter to the target device is what triggers quantization.
    params4bit = params4bit.to(device)

    chunks = torch.chunk(params4bit, 2, dim=0)

    assert isinstance(chunks, tuple), "torch.chunk should return tuple"
    # Guard against an empty result making the per-chunk checks pass vacuously.
    assert len(chunks) > 0, "Should have at least one chunk"
    for chunk in chunks:
        assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
        assert hasattr(chunk, "quant_type"), "Should preserve metadata"
        assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"

    splits = torch.split(params4bit, 2, dim=0)

    assert isinstance(splits, tuple), "torch.split should return tuple"
    assert len(splits) > 0, "Should have at least one split"
    for split in splits:
        assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
        assert hasattr(split, "quant_type"), "Should preserve metadata"
        assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"


250
251
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
    """A ``copy.deepcopy`` of a device-moved ``Params4bit`` is fully independent.

    The copy gets its own ``quant_state`` object and its own storage. It carries
    the same instance attributes as the source, and the source's attribute set
    is left untouched by the copy.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    original = bnb.nn.Params4bit(
        data=torch.randn(300, 400),
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)

    keys_before = set(original.__dict__)
    duplicate = copy.deepcopy(original)
    keys_after = set(original.__dict__)
    keys_of_duplicate = set(duplicate.__dict__)

    # Deep copy: no shared state object, no shared storage.
    assert original.quant_state is not duplicate.quant_state
    assert original.data.data_ptr() != duplicate.data.data_ptr()

    # there was a bug where deepcopy would modify the original object
    assert keys_before == keys_after
    assert keys_before == keys_of_duplicate

278

279
280
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
    """Pickle round-trip of a quantized ``Params4bit``.

    The unpickled parameter must match the original in several respects: data,
    ``requires_grad`` (expected to be ``False``), ``quant_type``, ``blocksize``,
    ``compress_statistics`` and ``quant_state``. Pickling must also not add or
    remove instance attributes on either object.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    original_tensor = torch.randn(300, 400)
    original_param = bnb.nn.Params4bit(
        data=original_tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
    )
    dict_keys_before = set(original_param.__dict__.keys())

    # Change device to trigger quantization. Keep the returned object rather than
    # relying on ``.to`` having mutated ``original_param`` in place.
    original_param = original_param.to(device)

    # The pickle payload is produced by this test itself, so loading it is safe.
    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)

    dict_keys_after = set(original_param.__dict__.keys())
    dict_keys_deserialized = set(deserialized_param.__dict__.keys())

    assert torch.equal(original_param.data, deserialized_param.data)
    assert original_param.requires_grad == deserialized_param.requires_grad
    assert not deserialized_param.requires_grad
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state

    # Pickling must neither modify the original object nor lose attributes on the copy.
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_deserialized
313
314
315
316
317
318
319
320
321
322
323


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
    """A ``torch.compile``-d stack of ``Linear4bit`` layers must match eager execution.

    Four ``Linear4bit`` layers are chained in a ``torch.nn.Sequential``. The
    compiled module's forward output is compared with the eager module's output
    for shape, device, dtype and value. The gradient each produces with respect
    to the input is compared as well.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
        pytest.skip("fullgraph mode requires torch 2.8 or higher")

    if device == "cuda" and platform.system() == "Windows":
        pytest.skip("Triton is not officially supported on Windows")

    # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
    on_linux_aarch64_cpu = device == "cpu" and platform.machine() == "aarch64" and platform.system() == "Linux"
    in_torch_2_6 = torch.__version__ >= (2, 6) and torch.__version__ < (2, 7)
    if not fullgraph and on_linux_aarch64_cpu and in_torch_2_6:
        pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")

    width = 256
    rows = 16

    torch.compiler.reset()

    # Four identically configured Linear4bit layers chained together.
    layers = [
        bnb.nn.Linear4bit(
            width,
            width,
            bias=bias,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        for _ in range(4)
    ]
    net = torch.nn.Sequential(*layers).to(device)

    inputs = torch.randn(rows, width, dtype=compute_dtype, device=device)

    # Eager reference, computed before compilation.
    with torch.no_grad():
        expected = net(inputs)

    backend = "inductor" if device != "hpu" else "hpu_backend"
    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode, backend=backend)

    with torch.no_grad():
        actual = compiled_net(inputs)

    assert actual.shape == expected.shape
    assert actual.device == expected.device
    assert actual.dtype == expected.dtype
    torch.testing.assert_close(actual, expected)

    # Backward pass: gradients w.r.t. the input must agree between eager and compiled.
    inputs.requires_grad_(True)
    net(inputs).sum().backward()
    eager_grad = inputs.grad.clone()

    inputs.grad = None
    compiled_net(inputs).sum().backward()
    compiled_grad = inputs.grad.clone()

    torch.testing.assert_close(compiled_grad, eager_grad)