import copy
from io import BytesIO
import os
import pickle
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE

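# Maps the quant_storage test parameter to the torch dtype used to hold the packed 4-bit data.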
storage = {
    "uint8": torch.uint8,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def torch_save_to_buffer(obj):
    buffer = BytesIO()
    torch.save(obj, buffer)
    buffer.seek(0)
    return buffer

def torch_load_from_buffer(buffer):
    buffer.seek(0)
    obj = torch.load(buffer)
    buffer.seek(0)
    return obj

@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
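    # Round-trips a Linear4bit layer through state_dict, an in-memory torch.save/torch.load,
    # and a CPU<->GPU move, checking that weights, quant_state, bias, and outputs all match.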
    original_dtype = torch.float16
    compute_dtype = None
    device = "cuda"
    layer_shape = (300, 400)

    linear = torch.nn.Linear(
        *layer_shape, dtype=original_dtype, device="cpu"
    )  # original layer

    # Quantizing original layer
    linear_q = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    new_weight = bnb.nn.Params4bit(
        data=linear.weight, quant_type=quant_type, requires_grad=False
    )
    linear_q.weight = new_weight
    if bias:
        linear_q.bias = torch.nn.Parameter(linear.bias)
    linear_q = linear_q.to(device)

    # saving to state_dict:
    sd = linear_q.state_dict()
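    # Beyond the packed weight (and optional bias), the state dict also carries the
    # serialized quant_state tensors (absmax, quant map, etc.) needed to rebuild the layer.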

    # restoring from state_dict:
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

    # creating new layer with same params:
    linear_q2 = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        device="meta",
    )
    # loading weights from state_dict:
    linear_q2.weight = weight2
    if bias:
        linear_q2.bias = torch.nn.Parameter(bias_data2)
    linear_q2 = linear_q2.to(device)

    # MATCHING
    a, b = linear_q.weight, linear_q2.weight

    # Quantizing original layer with specified quant_storage type
    linear_qs = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        compute_dtype=compute_dtype,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
        device="meta",
    )
    linear_qs.weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        quant_storage=storage[quant_storage],
    )
    if bias:
        linear_qs.bias = torch.nn.Parameter(linear.bias)
    linear_qs = linear_qs.to(device)

    assert a.device == b.device
    assert a.dtype == b.dtype
    assert torch.equal(a, b)

    q0 = a.quant_state
    q1 = b.quant_state
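    # Compare the quantization metadata field by field; state2 is the nested state used to
    # double-quantize absmax when compress_statistics is enabled.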
    for attr in ("code", "dtype", "blocksize", "absmax"):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"

    if q0.state2 is not None:
        for attr in ("code", "dtype", "blocksize", "absmax"):
            c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
            if isinstance(c, torch.Tensor):
                assert torch.equal(c, d)
            else:
                assert c == d, f"{c} != {d}"

    if bias:
        a, b = linear_q.bias, linear_q2.bias
        assert a.device == b.device
        assert a.dtype == b.dtype
        assert torch.equal(a, b)

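    # Depending on the parametrization, serialize the layer either before or after it has run.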
    if save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)

    # Forward test
    x = torch.rand(42, layer_shape[0], device=device)
    a = linear_q(x)
    b = linear_q2(x)
    c = linear_qs(x)
    assert a.device == b.device
    assert a.dtype == b.dtype
    assert a.device == c.device
    assert a.dtype == c.dtype
    assert torch.equal(a, b)
    assert torch.equal(a, c)

    if not save_before_forward:
        bytes_4bit = torch_save_to_buffer(linear_q)
    linear_q3 = torch_load_from_buffer(bytes_4bit)
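    # linear_q3 is the layer restored from the in-memory torch.save buffer; it is checked below.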

    # Test moving to CPU and back to GPU
    linear_q2.to("cpu")
    linear_q2.to(device)
    d = linear_q2(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    d = linear_q3(x)
    assert c.dtype == d.dtype
    assert c.device == d.device
    assert torch.equal(c, d)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig, size_4 = (
            os.path.getsize(state_path),
            os.path.getsize(state_path_4bit),
        )
        size_ratio = size_4 / size_orig
        target_compression = (
            0.143 if original_dtype == torch.float32 else 0.29
        )  # these numbers get lower as weight shape increases
        ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        assert size_ratio < target_compression, ratio_error_msg


def test_copy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)

    shallow_copy_param = copy.copy(param)
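    # A shallow copy should share both the quant_state object and the underlying storage.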
    assert param.quant_state is shallow_copy_param.quant_state
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


def test_deepcopy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
    copy_param = copy.deepcopy(param)
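    # A deep copy should own an independent quant_state and its own storage.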
    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()


def test_params4bit_real_serialization():
    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")

    original_param.cuda(0)  # move to CUDA to trigger quantization

    serialized_param = pickle.dumps(original_param)
    deserialized_param = pickle.loads(serialized_param)
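    # The deserialized parameter should match the original's data and quantization metadata.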

    assert torch.equal(original_param.data, deserialized_param.data)
    assert original_param.requires_grad == deserialized_param.requires_grad == False
    assert original_param.quant_type == deserialized_param.quant_type
    assert original_param.blocksize == deserialized_param.blocksize
    assert original_param.compress_statistics == deserialized_param.compress_statistics
    assert original_param.quant_state == deserialized_param.quant_state