from contextlib import nullcontext
import os
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
    TRUE_FALSE,
    id_formatter,
    torch_load_from_buffer,
    torch_save_to_buffer,
)

# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py

@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
    x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
    # col_turing uses 8x32 tiles, col_ampere 32x32; each layout is a fixed permutation of elements
    for tile_size, order in [((8, 32), "col_turing"), ((32, 32), "col_ampere")]:
        transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
        tile_indices = get_inverse_transform_indices(transform, tile_size)
        cxb = transform(x)

        torch.cuda.synchronize()
        restored_x = undo_layout(cxb, tile_indices)
        torch.cuda.synchronize()
        assert restored_x.is_contiguous()
        assert torch.all(torch.eq(restored_x, x))
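

# Not part of the original tests: a minimal sketch of the idea behind
# get_inverse_transform_indices / undo_layout. The tiled layout transform is a
# fixed permutation of elements, so it can be undone by indexing with the
# argsort of that permutation (plain PyTorch, no bitsandbytes required).
def _inverse_permutation_sketch():
    x = torch.arange(64)
    perm = torch.randperm(64)  # stand-in for the tiled layout transform
    inv = torch.argsort(perm)  # inverse permutation: inv[perm[i]] == i
    assert torch.equal(x[perm][inv], x)  # round trip restores the original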


def test_linear_no_igemmlt():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()
    linear = linear.half().cuda()

    x_ref = x.clone().cuda().requires_grad_(True)
    x_ours = x.clone().cuda().requires_grad_(True)
    fx_ref = linear(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()
    assert torch.allclose(fx_ref, fx_ours, atol=0.02)
    assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is not None
    assert linear_custom.state.CxB is None
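

# Not part of the original tests: a hedged sketch of the conversion pattern
# the test above exercises, wrapping a trained float torch.nn.Linear as an
# inference-only Linear8bitLt. It reuses only the bnb calls already imported
# in this file; threshold=6.0 mirrors the tests here.
def _wrap_linear_as_int8_sketch(linear: torch.nn.Linear) -> Linear8bitLt:
    qlinear = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,  # keep only int8 weights (inference mode)
        threshold=6.0,  # outlier threshold for the mixed int8/fp16 matmul
    )
    qlinear.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    )
    qlinear.bias = linear.bias
    return qlinear.cuda()  # weights are quantized on the move to the GPU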


@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
def test_linear_serialization(
    has_fp16_weights,
    serialize_before_forward,
    deserialize_before_cuda,
    force_no_igemmlt,
    save_before_forward,
    load_before_cuda,
):
    linear = torch.nn.Linear(32, 96)
    x = torch.randn(3, 32, dtype=torch.half)

    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=has_fp16_weights,
        threshold=6.0,
    )
    if force_no_igemmlt:
        linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
    )
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()

    if serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    if save_before_forward:
        bytes_8bit = torch_save_to_buffer(linear_custom)

    x_first = x.clone().cuda().requires_grad_(True)
    fx_first = linear_custom(x_first).float()
    grad_proj = torch.randn_like(fx_first)
    (fx_first * grad_proj).mean().backward()

    if not serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    if not save_before_forward:
        bytes_8bit = torch_save_to_buffer(linear_custom)

    with TemporaryDirectory() as tmpdir:
        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")

        torch.save(linear.state_dict(), state_path)
        torch.save(state_dict_8bit, state_path_8bit)

        if not has_fp16_weights:
            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)

        new_state_dict = torch.load(state_path_8bit)

    new_linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=has_fp16_weights,
        threshold=6.0,
    )
    if force_no_igemmlt:
        new_linear_custom.state.force_no_igemmlt = True

    if deserialize_before_cuda:
        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
            new_linear_custom.load_state_dict(new_state_dict, strict=True)

    if load_before_cuda:
        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

    new_linear_custom = new_linear_custom.cuda()

    if not deserialize_before_cuda:
        new_linear_custom.load_state_dict(new_state_dict, strict=True)

    if not load_before_cuda:
        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

    x_second = x.clone().cuda().requires_grad_(True)
    fx_second = new_linear_custom(x_second).float()
    (fx_second * grad_proj).mean().backward()

    x_third = x.clone().cuda().requires_grad_(True)
    fx_third = new_linear_custom2(x_third).float()
    (fx_third * grad_proj).mean().backward()

    # if 8-bit weights were loaded before the .cuda() call, the state is invalid anyway and a RuntimeError was already raised above
    if has_fp16_weights or not deserialize_before_cuda:
        assert torch.allclose(fx_first, fx_second, atol=1e-5)
        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
    assert torch.allclose(fx_first, fx_third, atol=1e-5)
    assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
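

# Not part of the original tests: a sketch of what torch_save_to_buffer /
# torch_load_from_buffer in tests.helpers are assumed to do, namely a
# torch.save / torch.load round trip through an in-memory buffer rather than
# a file on disk.
def _buffer_roundtrip_sketch(module: torch.nn.Module) -> torch.nn.Module:
    import io

    buffer = io.BytesIO()
    torch.save(module, buffer)  # serialize the module, 8-bit buffers included
    buffer.seek(0)  # rewind so torch.load reads from the start
    return torch.load(buffer)  # an independent copy of the module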