test_optim.py 20.3 KB
Newer Older
Tim Dettmers's avatar
Tim Dettmers committed
1
import os
from os.path import join
import shutil
import tempfile
import time
import uuid

from lion_pytorch import Lion
import pytest
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F
from tests.helpers import describe_dtype, id_formatter
Tim Dettmers's avatar
Tim Dettmers committed
14

15
# import apex
Tim Dettmers's avatar
Tim Dettmers committed
16
17

k = 20
Tim Dettmers's avatar
Tim Dettmers committed
18

Ruff's avatar
Ruff committed
19

20
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
Tim Dettmers's avatar
Tim Dettmers committed
21
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
22
23
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
24
        print(f"Too many values not close: assert {error_count} < {max_error_count}")
Tim Dettmers's avatar
Tim Dettmers committed
25
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
26

27

Tim Dettmers's avatar
Tim Dettmers committed
28
def get_temp_dir():
Aarni Koskela's avatar
Aarni Koskela committed
29
    path = f"/tmp/autoswap/{uuid.uuid4()}"
Tim Dettmers's avatar
Tim Dettmers committed
30
31
32
    os.makedirs(path, exist_ok=True)
    return path

33

Tim Dettmers's avatar
Tim Dettmers committed
34
35
36
def rm_path(path):
    shutil.rmtree(path)

Ruff's avatar
Ruff committed
37

Tim Dettmers's avatar
Tim Dettmers committed
38
str2optimizers = {}
39
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
40
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
41
42
43
44
45
46
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
Tim Dettmers's avatar
Tim Dettmers committed
47
48
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
49
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
Tim Dettmers's avatar
Tim Dettmers committed
50
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
51
52
53
54
55
56
57
58
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
Tim Dettmers's avatar
Tim Dettmers committed
59
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
60
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
61
62
63
64
65
66
67
68
69
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)

Tim Dettmers's avatar
Tim Dettmers committed
70
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
Ruff's avatar
Ruff committed
71
72
73
74
75
76
77
78
str2optimizers["paged_adamw8bit_blockwise"] = (
    torch.optim.AdamW,
    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)
str2optimizers["paged_adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
79
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
Tim Dettmers's avatar
Tim Dettmers committed
80
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
81
82
83
84
85
86
87
88
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)
Tim Dettmers's avatar
Tim Dettmers committed
89
90

str2statenames = {}
91
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
Tim Dettmers's avatar
Tim Dettmers committed
92
93
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
94
str2statenames["lion"] = [("exp_avg", "state1")]
Tim Dettmers's avatar
Tim Dettmers committed
95
str2statenames["paged_lion"] = [("exp_avg", "state1")]
96
97
98
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
Tim Dettmers's avatar
Tim Dettmers committed
99
100
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
Ruff's avatar
Ruff committed
101
102
103
104
105
106
107
108
109
110
111
112
str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
Tim Dettmers's avatar
Tim Dettmers committed
113
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
114
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
Tim Dettmers's avatar
Tim Dettmers committed
115
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
116
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
Tim Dettmers's avatar
Tim Dettmers committed
117
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
118
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
Tim Dettmers's avatar
Tim Dettmers committed
119
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
Tim Dettmers's avatar
Tim Dettmers committed
120

Ruff's avatar
Ruff committed
121
optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"]
Aarni Koskela's avatar
Aarni Koskela committed
122
123
124
125
126
127


@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
Tim Dettmers's avatar
Tim Dettmers committed
128
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
Ruff's avatar
Ruff committed
129
    if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
Aarni Koskela's avatar
Aarni Koskela committed
130
        pytest.skip()
131
132
133
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
134
135
136
137
138
139
140
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
Tim Dettmers's avatar
Tim Dettmers committed
141
        atol, rtol = 1e-6, 1e-5
142
143
    elif gtype == torch.bfloat16:
        atol, rtol = 1e-3, 1e-2
Tim Dettmers's avatar
Tim Dettmers committed
144
145
146
    else:
        atol, rtol = 1e-4, 1e-3

Tim Dettmers's avatar
Tim Dettmers committed
147
    for i in range(k):
148
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
Tim Dettmers's avatar
Tim Dettmers committed
149
150
151
152
153
154
155
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
Tim Dettmers's avatar
Tim Dettmers committed
156
            torch.testing.assert_close(
157
                torch_optimizer.state[p1][name1],
Tim Dettmers's avatar
Tim Dettmers committed
158
                bnb_optimizer.state[p2][name2].cuda(),
159
160
161
                atol=atol,
                rtol=rtol,
            )
Tim Dettmers's avatar
Tim Dettmers committed
162

163
164
        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 10 errors for Lion
165
        assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
Tim Dettmers's avatar
Tim Dettmers committed
166

167
        if i % (k // 5) == 0 and i > 0:
Tim Dettmers's avatar
Tim Dettmers committed
168
            path = get_temp_dir()
169
            torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
Tim Dettmers's avatar
Tim Dettmers committed
170
171
172
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
173
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
Tim Dettmers's avatar
Tim Dettmers committed
174
            rm_path(path)
175
176
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
177
            assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
Tim Dettmers's avatar
Tim Dettmers committed
178
            for name1, name2 in str2statenames[optim_name]:
179
180
                # since Lion can have pretty noisy updates where things lie at the boundary
                # allow up to 10 errors for Lion
Ruff's avatar
Ruff committed
181
182
183
184
185
186
187
                assert_most_approx_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                    max_error_count=10,
                )
Tim Dettmers's avatar
Tim Dettmers committed
188

189
        if gtype != torch.float32:
Tim Dettmers's avatar
Tim Dettmers committed
190
            # the adam buffers should also be close because they are 32-bit
191
            # but the parameters can diverge because they are 16-bit
Tim Dettmers's avatar
Tim Dettmers committed
192
193
            # the difference grow larger and larger with each update
            # --> copy the state to keep weights close
194
            p1.data = p1.data.to(p2.dtype).float()
Tim Dettmers's avatar
Tim Dettmers committed
195
            p2.copy_(p1.data)
Tim Dettmers's avatar
Tim Dettmers committed
196
            torch.testing.assert_close(p1.to(p2.dtype), p2)
197
198
199
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0

Tim Dettmers's avatar
Tim Dettmers committed
200

Aarni Koskela's avatar
Aarni Koskela committed
201
202
203
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
Tim Dettmers's avatar
Tim Dettmers committed
204
def test_global_config(dim1, dim2, gtype):
205
206
207
208
209
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
210
211
212
213
214
215
216
    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
Ruff's avatar
Ruff committed
217
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
Tim Dettmers's avatar
Tim Dettmers committed
218

Ruff's avatar
Ruff committed
219
    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
Tim Dettmers's avatar
Tim Dettmers committed
220
221
222
223
224
225
226
227
228
229
230
231
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
232
233
234
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
Tim Dettmers's avatar
Tim Dettmers committed
235
236
237
238
239
240
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

241
242
        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8
Tim Dettmers's avatar
Tim Dettmers committed
243
244


Aarni Koskela's avatar
Aarni Koskela committed
245
optimizer_names_8bit = [
246
    "adam8bit",
247
    "lion8bit",
248
249
250
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
251
    "lion8bit_blockwise",
252
253
254
255
256
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]


Aarni Koskela's avatar
Aarni Koskela committed
257
258
259
260
@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
Tim Dettmers's avatar
Tim Dettmers committed
261
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
Ruff's avatar
Ruff committed
262
263
    if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]:
        pytest.skip()
264
265
266
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
267
268
269
270
271
272
273
274
275
276
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
Tim Dettmers's avatar
Tim Dettmers committed
277
278
279
    elif gtype == torch.bfloat16:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-4, 1e-2
Tim Dettmers's avatar
Tim Dettmers committed
280
281
282
283
284
285
286
    else:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

287
    for i in range(100):
288
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
Tim Dettmers's avatar
Tim Dettmers committed
289
290
291
292
293
294
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

295
296
297
        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 5 errors for Lion
        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
Tim Dettmers's avatar
Tim Dettmers committed
298
299
300

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
301
302
303
304
305
306
307
308
            # print(bnb_optimizer.state[p2][max_val], name1)
            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                    blocksize=blocksize,
                )
Tim Dettmers's avatar
Tim Dettmers committed
309
            else:
310
311
312
313
314
                s1 = F.dequantize(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
Ruff's avatar
Ruff committed
315
316
            num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
            # assert num_not_close.sum().item() < 20
Tim Dettmers's avatar
Tim Dettmers committed
317
318
            dequant_states.append(s1.clone())

319
        err = torch.abs(p1 - p2)
Ruff's avatar
Ruff committed
320
        relerr = err / (torch.abs(p1) + 1e-9)
Tim Dettmers's avatar
Tim Dettmers committed
321
322
        if g.dtype == torch.bfloat16:
            assert err.mean() < 0.00015
323
            assert relerr.mean() < 0.0016
Tim Dettmers's avatar
Tim Dettmers committed
324
        else:
325
326
            assert err.mean() < 0.00012
            assert relerr.mean() < 0.0012
Tim Dettmers's avatar
Tim Dettmers committed
327
328
329
330
331

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
Ruff's avatar
Ruff committed
332
            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
Tim Dettmers's avatar
Tim Dettmers committed
333
334
335
336
337
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
338
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
Tim Dettmers's avatar
Tim Dettmers committed
339
340
341
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
342
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
Tim Dettmers's avatar
Tim Dettmers committed
343
                rm_path(path)
Tim Dettmers's avatar
Tim Dettmers committed
344
345
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
Tim Dettmers's avatar
Tim Dettmers committed
346

347
348
349
350
351
352
353
                if "blockwise" in optim_name:
                    s1 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )
Tim Dettmers's avatar
Tim Dettmers committed
354
                else:
355
356
357
358
359
                    s1 = F.dequantize(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                    )
Tim Dettmers's avatar
Tim Dettmers committed
360
                torch.testing.assert_close(s1cpy, s1)
Tim Dettmers's avatar
Tim Dettmers committed
361

Ruff's avatar
Ruff committed
362
                num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
Tim Dettmers's avatar
Tim Dettmers committed
363
                assert num_not_close.sum().item() < 20
364
365
366
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 5 errors for Lion
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
Tim Dettmers's avatar
Tim Dettmers committed
367
368
369
370
371

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
Tim Dettmers's avatar
Tim Dettmers committed
372
        torch.testing.assert_close(p1.to(gtype), p2)
Tim Dettmers's avatar
Tim Dettmers committed
373
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
Tim Dettmers's avatar
Tim Dettmers committed
374
375
            torch_optimizer.state[p1][name1].copy_(s.data)

376
377
    # print(sum(errors)/len(errors))
    # print(sum(relerrors)/len(relerrors))
Tim Dettmers's avatar
Tim Dettmers committed
378
379


Aarni Koskela's avatar
Aarni Koskela committed
380
381
382
383
@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits"))
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
Tim Dettmers's avatar
Tim Dettmers committed
384
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
385
386
387
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
388
389
390
391
392
393
394
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
395
    adam2 = bnb.optim.Adam(
396
397
398
399
400
401
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
402
    )
Tim Dettmers's avatar
Tim Dettmers committed
403
404
405
406
407
408

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
Ruff's avatar
Ruff committed
409
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
Tim Dettmers's avatar
Tim Dettmers committed
410
411
412
        g2 = g1.clone()
        p2.grad = g2

Ruff's avatar
Ruff committed
413
        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
414
        g1 = (g1.float() * gnorm_scale).to(gtype)
Tim Dettmers's avatar
Tim Dettmers committed
415
416
417
418
419
420
421
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
Tim Dettmers's avatar
Tim Dettmers committed
422
423
            torch.testing.assert_close(p1, p2)
            torch.testing.assert_close(
424
425
426
427
428
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
Tim Dettmers's avatar
Tim Dettmers committed
429
            torch.testing.assert_close(
430
431
432
433
434
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
Tim Dettmers's avatar
Tim Dettmers committed
435
        elif optim_bits == 8:
Tim Dettmers's avatar
Tim Dettmers committed
436
437
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_close(
438
439
440
441
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
442
            )
Tim Dettmers's avatar
Tim Dettmers committed
443
            torch.testing.assert_close(
444
445
446
447
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
                rtol=1e-3,
448
449
450
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
Tim Dettmers's avatar
Tim Dettmers committed
451
452
        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
453
            torch.save(adam2.state_dict(), join(path, "opt.pt"))
Tim Dettmers's avatar
Tim Dettmers committed
454
455
            del adam2
            adam2 = None
456
457
458
459
460
461
462
463
464
            adam2 = bnb.optim.Adam(
                [p2],
                lr,
                (beta1, beta2),
                eps,
                optim_bits=optim_bits,
                percentile_clipping=5,
            )
            adam2.load_state_dict(torch.load(join(path, "opt.pt")))
Tim Dettmers's avatar
Tim Dettmers committed
465
466


Aarni Koskela's avatar
Aarni Koskela committed
467
468
469
470
471
optimizer_names_benchmark = [
    "adam8bit_blockwise",
    "paged_adam8bit_blockwise",
    "paged_adamw8bit_blockwise",
    "paged_lion8bit_blockwise",
472
]
473
474


Aarni Koskela's avatar
Aarni Koskela committed
475
476
477
478
479
@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
@pytest.mark.benchmark
Tim Dettmers's avatar
Tim Dettmers committed
480
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
481
482
483
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
484
485
486

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

487
    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
Tim Dettmers's avatar
Tim Dettmers committed
488
    p1.grad = g
Tim Dettmers's avatar
Tim Dettmers committed
489
    for i in range(k):
490
        if i == k // 5:
Tim Dettmers's avatar
Tim Dettmers committed
491
492
493
494
495
496
497
            # 100 iterations for burn-in
            torch.cuda.synchronize()
            t0 = time.time()

        bnb_optimizer.step()

    torch.cuda.synchronize()
498
499
500
501
502
    s = time.time() - t0
    print("")
    params = (k - k // 5) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9
503

Aarni Koskela's avatar
Aarni Koskela committed
504
505
506

@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
Ruff's avatar
Ruff committed
507
508
@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
Aarni Koskela's avatar
Aarni Koskela committed
509
@pytest.mark.benchmark
510
511
512
513
514
515
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
    layers1 = layers1.to(gtype)
    layers1 = layers1.cuda()

    large_tensor = None
Ruff's avatar
Ruff committed
516
    if mode == "torch":
517
518
519
520
        optim = str2optimizers[optim_name][0](layers1.parameters())
    else:
        optim = str2optimizers[optim_name][1](layers1.parameters())
        # 12 GB
Ruff's avatar
Ruff committed
521
        large_tensor = torch.empty((int(4.5e9),), device="cuda")
522
523
524
525
526

    torch.cuda.synchronize()
    time.sleep(5)

    num_batches = 5
Ruff's avatar
Ruff committed
527
528
    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
529
530
531
532

    for i in range(num_batches):
        print(i)
        b = batches[i]
Ruff's avatar
Ruff committed
533
        if i == 2:
534
535
536
537
538
539
540
541
542
543
            torch.cuda.synchronize()
            t0 = time.time()

        out1 = layers1(b)

        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
        loss1.backward()
        optim.step()
    torch.cuda.synchronize()
    print(mode, time.time() - t0)