# test_linear4bit.py -- tests for bitsandbytes Linear4bit (last change committed by Ruslan Svirschevski).
import os
from contextlib import nullcontext
from itertools import product
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb

# String name -> torch dtype for each ``quant_storage`` value exercised by the
# serialization test's parametrization.
storage = {
    name: getattr(torch, name)
    for name in ("uint8", "float16", "bfloat16", "float32")
}

@pytest.mark.parametrize(
    "quant_type, compress_statistics, bias, quant_storage",
    list(product(["nf4", "fp4"], [False, True], [False, True], list(storage))),
)
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
    """Round-trip a ``Linear4bit`` layer through its state_dict and check the copy matches.

    A float16 ``torch.nn.Linear`` is quantized into ``bnb.nn.Linear4bit``.
    A second layer is rebuilt from the saved state via ``Params4bit.from_prequantized``.
    The test then asserts that the weights, quant-state attributes, biases and forward
    outputs of the two layers are identical.

    It also checks that:
    * a layer built with the parametrized ``quant_storage`` dtype gives the same output;
    * the restored layer still gives the same output after a CUDA -> CPU -> CUDA move;
    * the 4-bit state_dict is much smaller on disk than the original one.
    """
    original_dtype = torch.float16
    compute_dtype = None
    device = "cuda"
    layer_shape = (300, 400)
    layer_kwargs = {
        "compute_dtype": compute_dtype,
        "compress_statistics": compress_statistics,
        "quant_type": quant_type,
    }

    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

    # Quantizing original layer. compress_statistics is passed explicitly: this
    # Params4bit replaces the layer's weight, so without it Params4bit's own
    # default would apply and the parametrized value would never be tested.
    weight = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        compress_statistics=compress_statistics,
    )
    linear_q = _make_linear4bit(linear, weight, linear.bias, bias=bias, device=device, **layer_kwargs)

    # saving to state_dict:
    sd = linear_q.state_dict()

    # restoring from state_dict:
    bias_data2 = sd.pop("bias", None)
    weight_data2 = sd.pop("weight")
    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

    # creating new layer with same params and loading weights from state_dict:
    linear_q2 = _make_linear4bit(linear, weight2, bias_data2, bias=bias, device=device, **layer_kwargs)

    # Quantizing original layer with specified quant_storage type
    weight_s = bnb.nn.Params4bit(
        data=linear.weight,
        requires_grad=False,
        quant_type=quant_type,
        compress_statistics=compress_statistics,
        quant_storage=storage[quant_storage],
    )
    linear_qs = _make_linear4bit(
        linear,
        weight_s,
        linear.bias,
        bias=bias,
        device=device,
        quant_storage=storage[quant_storage],
        **layer_kwargs,
    )

    # MATCHING: restored weight, quant state and bias must equal the originals.
    w1, w2 = linear_q.weight, linear_q2.weight
    assert w1.device == w2.device
    assert w1.dtype == w2.dtype
    assert torch.equal(w1, w2)

    q0, q1 = w1.quant_state, w2.quant_state
    _assert_attrs_equal(q0, q1, ("code", "dtype", "blocksize", "absmax", "shape", "quant_type"))
    # Check presence first so a missing nested state fails as an assertion,
    # not as an AttributeError inside the comparison.
    assert (q0.state2 is None) == (q1.state2 is None)
    if q0.state2 is not None:
        _assert_attrs_equal(q0.state2, q1.state2, ("code", "dtype", "blocksize", "absmax"))

    if bias:
        b1, b2 = linear_q.bias, linear_q2.bias
        assert b1.device == b2.device
        assert b1.dtype == b2.dtype
        assert torch.equal(b1, b2)

    # Forward test: restored and custom-storage layers must match the original.
    x = torch.rand(42, layer_shape[0], device=device)
    out_q = linear_q(x)
    out_q2 = linear_q2(x)
    out_qs = linear_qs(x)
    for out in (out_q2, out_qs):
        assert out.device == out_q.device
        assert out.dtype == out_q.dtype
        assert torch.equal(out, out_q)

    # Test moving to CPU and back to GPU: the moved layer itself must still
    # produce the same output (nn.Module.to modifies the module in place).
    linear_q2.to("cpu")
    linear_q2.to(device)
    out_roundtrip = linear_q2(x)
    assert out_roundtrip.dtype == out_q2.dtype
    assert out_roundtrip.device == out_q2.device
    assert torch.equal(out_roundtrip, out_q2)

    # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
    with TemporaryDirectory() as tmpdir:
        state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")
        torch.save(linear.state_dict(), state_path)
        torch.save(linear_q.state_dict(), state_path_4bit)

        size_orig = os.path.getsize(state_path)
        size_4 = os.path.getsize(state_path_4bit)
        size_ratio = size_4 / size_orig
        # These numbers get lower as weight shape increases.
        if original_dtype == torch.float32:
            target_compression = 0.143
        else:
            # Uncompressed statistics keep the per-block absmax unquantized,
            # which adds several KB at this layer size; bound is an estimate --
            # NOTE(review): tighten once measured on CI.
            target_compression = 0.29 if compress_statistics else 0.31
        ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
        assert size_ratio < target_compression, ratio_error_msg


def _make_linear4bit(linear, weight, bias_data, *, bias, device, **layer_kwargs):
    """Build a ``bnb.nn.Linear4bit`` with the same shape as *linear* and move it to *device*.

    Steps:
    1. Construct the layer on the ``meta`` device.
    2. Replace its placeholder weight with *weight* (a ``Params4bit``).
    3. When *bias* is true, install *bias_data* as the bias parameter.
    4. Move the layer to *device*.

    Extra keyword arguments (``compute_dtype``, ``compress_statistics``,
    ``quant_type``, ``quant_storage``) are forwarded to ``Linear4bit``.
    Returns the result of ``layer.to(device)``.
    """
    layer = bnb.nn.Linear4bit(
        linear.in_features,
        linear.out_features,
        bias=bias,
        device="meta",
        **layer_kwargs,
    )
    layer.weight = weight
    if bias:
        layer.bias = torch.nn.Parameter(bias_data)
    return layer.to(device)


def _assert_attrs_equal(s0, s1, attrs):
    """Assert that *s0* and *s1* agree on every attribute named in *attrs*.

    Tensor attributes are compared with ``torch.equal``; all others with ``==``.
    """
    for attr in attrs:
        c, d = getattr(s0, attr), getattr(s1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d)
        else:
            assert c == d, f"{c} != {d}"