"vscode:/vscode.git/clone" did not exist on "d9a3018fb7b312f471cbe856a7bbe36076574d37"
test_linear4bit.py 9.84 KB
Newer Older
1
import copy
import os
import pickle
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer

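# Maps quant_storage test ids to the torch dtypes used to hold the packed 4-bit weights.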
storage = {
    "uint8": torch.uint8,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
    if device == "cpu":
        if quant_type == "fp4":
            pytest.xfail("FP4 is not supported for CPU")
        if quant_storage != "uint8":
            pytest.xfail("Only uint8 storage is supported for CPU")

    original_dtype = torch.float16
    compute_dtype = None
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
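    # Assigning a Params4bit weight and then moving the layer to the target device triggers quantization.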
    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
    linear_q = linear_q.to(device)

    # saving to state_dict:
    sd = linear_q.state_dict()

    # restoring from state_dict:
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
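    # from_prequantized rebuilds the Params4bit, reconstructing its quant_state from the remaining state_dict entries.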
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    # loading weights from state_dict:
    linear_q2.weight = weight2
    if bias:
        linear_q2.bias = torch.nn.Parameter(bias_data2)
    linear_q2 = linear_q2.to(device)

    # MATCHING: weights restored from the state_dict must equal the originals
    a, b = linear_q.weight, linear_q2.weight

    # Quantizing original layer with specified quant_storage type
    linear_qs = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

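    # Compare the quantization metadata (codebook, dtype, blocksize, absmax) of both layers field by field.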
    q0 = a.quant_state
    q1 = b.quant_state
    for attr in ("code", "dtype", "blocksize", "absmax"):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"

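    # state2 is the nested quant state used to compress absmax when compress_statistics=True.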
    if q0.state2 is not None:
        for attr in ("code", "dtype", "blocksize", "absmax"):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)
            else:
                assert c == d, f"{c} != {d}"

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

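    # Serialization is exercised both before and after the first forward pass, parametrized via save_before_forward.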
    if save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)

    # Forward test
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    c = linear_qs(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert a.device == c.device
    assert a.dtype == c.dtype
    assert torch.equal(a, b)
    assert torch.equal(a, c)

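    # Round-trip the quantized layer through a torch.save/torch.load buffer.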
    if not save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)
    linear_q3 = torch_load_from_buffer(bytes_4bit)

    # Test moving to CPU and back to the accelerator device
    if device != "cpu":
        linear_q2.to("cpu")
        linear_q2.to(device)
    d = linear_qs(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

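    # The layer restored from the serialized buffer must produce identical outputs.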
    d = linear_q3(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
        ratio_error_msg = (
            f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        )
        assert size_ratio < target_compression, ratio_error_msg


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
    if device == "cpu":
        if compress_statistics:
            pytest.skip("Currently segfaults on CPU")
        if quant_type == "fp4":
            pytest.xfail("FP4 not supported on CPU")

    tensor = torch.linspace(1, blocksize, blocksize)
    param = bnb.nn.Params4bit(
        data=tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)

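    # A shallow copy must share the quant_state object and the underlying storage with the original.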
    shallow_copy_param = copy.copy(param)
    assert param.quant_state is shallow_copy_param.quant_state
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
    if device == "cpu":
        if compress_statistics:
            pytest.skip("Currently segfaults on CPU")
        if quant_type == "fp4":
            pytest.xfail("FP4 not supported on CPU")

    tensor = torch.linspace(1, blocksize, blocksize)
    param = bnb.nn.Params4bit(
        data=tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        requires_grad=False,
    ).to(device)
    dict_keys_before = set(param.__dict__.keys())
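    # A deep copy must duplicate the quant_state and the packed data rather than alias them.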
    copy_param = copy.deepcopy(param)
    dict_keys_after = set(param.__dict__.keys())
    dict_keys_copy = set(copy_param.__dict__.keys())

    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()

    # there was a bug where deepcopy would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_copy


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
    if device == "cpu":
        if compress_statistics:
            pytest.skip("Currently segfaults on CPU")
        if quant_type == "fp4":
            pytest.xfail("FP4 not supported on CPU")

    original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
    original_param = bnb.nn.Params4bit(
        data=original_tensor,
        quant_type=quant_type,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
    )
    dict_keys_before = set(original_param.__dict__.keys())

    original_param.to(device)  # change device to trigger quantization

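    # Round-trip the quantized parameter through pickle.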
    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)
    dict_keys_after = set(original_param.__dict__.keys())
    dict_keys_deserialized = set(deserialized_param.__dict__.keys())

    assert torch.equal(original_param.data, deserialized_param.data)
    assert original_param.requires_grad is False
    assert deserialized_param.requires_grad is False
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state

    # there was a bug where serialization would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_deserialized