# test_linear4bit.py
import copy
import os
import pickle
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer

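# torch dtypes accepted for `quant_storage`, i.e. the dtype of the tensor that
# holds the packed 4-bit payload; the bytes are the same, only viewed differently.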
storage = {
    "uint8": torch.uint8,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
    original_dtype = torch.float16
    compute_dtype = None
    device = "cuda"
    layer_shape = (300, 400)

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
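    # Params4bit quantizes lazily: the actual 4-bit packing happens here, when
    # the layer is moved onto the GPU.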
    linear_q = linear_q.to(device)

    # saving to state_dict:
    sd = linear_q.state_dict()

    # restoring from state_dict:
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
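    # from_prequantized() rebuilds the Params4bit from the already-packed data
    # plus the quant-state entries remaining in `sd`; nothing is re-quantized.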
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    # loading weights from state_dict:
    linear_q2.weight = weight2
    if bias:
        linear_q2.bias = torch.nn.Parameter(bias_data2)
    linear_q2 = linear_q2.to(device)

    # MATCHING
    a, b = linear_q.weight, linear_q2.weight

    # Quantizing original layer with specified quant_storage type
    linear_qs = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

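    # The round-tripped weight must match the original quantized weight exactly.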
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

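    # quant_state carries everything needed to dequantize: the 4-bit codebook
    # ("code"), the per-block absmax scales, the blocksize, and the source dtype.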
    q0 = a.quant_state
    q1 = b.quant_state
    for attr in ("code", "dtype", "blocksize", "absmax"):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"

    if q0.state2 is not None:
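        # with compress_statistics=True the absmax tensor is itself quantized;
        # state2 holds that nested quantization state and must round-trip too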
        for attr in ("code", "dtype", "blocksize", "absmax"):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)
            else:
                assert c == d, f"{c} != {d}"

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

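    # Serialize either before or after the first forward pass; both orders must
    # yield a loadable artifact that behaves identically (checked via linear_q3).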
    if save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)

    # Forward test
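    # All three layers hold the same quantized weights, so their outputs must
    # be bitwise identical on the same input.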
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    c = linear_qs(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert a.device == c.device
    assert a.dtype == c.dtype
    assert torch.equal(a, b)
    assert torch.equal(a, c)

    if not save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)
    linear_q3 = torch_load_from_buffer(bytes_4bit)

    # Test moving to CPU and back to GPU
    linear_q2.to("cpu")
    linear_q2.to(device)
    d = linear_q2(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    d = linear_q3(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
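        # rough arithmetic for the fp16 case: 300 * 400 = 120,000 weights take
        # ~240 KB in fp16 but ~60 KB packed at 4 bits; absmax scales and pickle
        # overhead keep the observed ratio just under the 0.29 target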
        ratio_error_msg = (
            f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        )
        assert size_ratio < target_compression, ratio_error_msg


def test_copy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
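    # moving to CUDA triggers quantization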
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)

    shallow_copy_param = copy.copy(param)
    assert param.quant_state is shallow_copy_param.quant_state
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


def test_deepcopy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
    dict_keys_before = set(param.__dict__.keys())
    copy_param = copy.deepcopy(param)
    dict_keys_after = set(param.__dict__.keys())
    dict_keys_copy = set(copy_param.__dict__.keys())

    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()

    # there was a bug where deepcopy would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_copy


def test_params4bit_real_serialization():
    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
    dict_keys_before = set(original_param.__dict__.keys())

    original_param.cuda(0)  # move to CUDA to trigger quantization

    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)
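    # the attribute set must survive quantization and pickling unchanged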
    dict_keys_after = set(original_param.__dict__.keys())
    dict_keys_deserialized = set(deserialized_param.__dict__.keys())

    assert torch.equal(original_param.data, deserialized_param.data)
    assert original_param.requires_grad == deserialized_param.requires_grad == False
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state

    # there was a bug where serialization would modify the original object
    assert dict_keys_before == dict_keys_after
    assert dict_keys_before == dict_keys_deserialized