import os
from os.path import join
import shutil
import time
import uuid

from lion_pytorch import Lion
import pytest
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F
from tests.helpers import describe_dtype, id_formatter

# import apex

k = 20
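# k is the number of optimizer steps per test loop (and the schedule horizon
# used for the "ademamix_scheduled" configuration below).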


def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
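    """Assert that at most ``max_error_count`` elements of ``a`` and ``b`` differ.

    Elements failing ``torch.isclose`` are counted; only when the count exceeds
    the budget do we fall back to the strict ``torch.testing.assert_close``,
    which raises. Quantized or noisy updates (e.g. Lion) can leave a handful of
    values sitting exactly on a quantization boundary, hence the error budget.
    """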
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
        print(f"Too many values not close: assert {error_count} < {max_error_count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def get_temp_dir():
    path = f"/tmp/autoswap/{uuid.uuid4()}"
    os.makedirs(path, exist_ok=True)
    return path

def rm_path(path):
    shutil.rmtree(path)

str2optimizers = {}
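# Maps each test id to a (reference_optimizer, bnb_optimizer) constructor pair;
# the tests step both on cloned parameters and compare results, e.g.:
#   torch_opt_cls, bnb_opt_cls = str2optimizers["adam8bit_blockwise"]
# (The three "*_pytorch" entries below are legacy (None, torch, bnb) triples.)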

## TODO: maybe remove these three.
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)

str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["paged_adamw8bit_blockwise"] = (
    torch.optim.AdamW,
    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)

str2optimizers["ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.AdEMAMix)
str2optimizers["ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx),
)
str2optimizers["paged_ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.PagedAdEMAMix)
str2optimizers["paged_ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx),
)
str2optimizers["ademamix_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
    lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)

str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))

str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2statenames = {}
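# Maps each test id to the optimizer-state tensors to compare: 32-bit entries
# are (torch_state_key, bnb_state_key) pairs, while 8-bit entries also carry
# the quantization map and scale keys, i.e. (torch_key, bnb_key, qmap_key,
# max_key), with "max1"/"max2" for per-tensor and "absmax1"/"absmax2" for
# blockwise state.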
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_ademamix8bit_blockwise"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]

optimizer_names_32bit = [
    "adam",
    "paged_adamw",
    "paged_adam",
    "momentum",
    "rmsprop",
    "lion",
    "paged_lion",
    "ademamix",
    "ademamix_scheduled",
    "paged_ademamix",
]


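# Step the bnb optimizer against its full-precision reference for k iterations,
# comparing parameters and optimizer state each step; every k // 5 steps the
# bnb optimizer is round-tripped through save/load to check serialization.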
@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
        pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    elif gtype == torch.bfloat16:
        atol, rtol = 1e-3, 1e-2
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2].cuda(),
                atol=atol,
                rtol=rtol,
            )

        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 10 errors for Lion
        assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
            assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
            for name1, name2 in str2statenames[optim_name]:
                # since Lion can have pretty noisy updates where things lie at the boundary
                # allow up to 10 errors for Lion
                assert_most_approx_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                    max_error_count=10,
                )

        if gtype != torch.float32:
            # the adam buffers should also be close because they are 32-bit
            # but the parameters can diverge because they are 16-bit
            # the difference grows larger and larger with each update
            # --> copy the parameter values over to keep the weights close
            p1.data = p1.data.to(p2.dtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_close(p1.to(p2.dtype), p2)
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0


@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_global_config(dim1, dim2, gtype):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
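    # Force 8-bit optimizer state for p3 only; p1 and p2 keep the default
    # 32-bit state. The override is set and the parameters are registered
    # before the move to CUDA and before the optimizer is constructed.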
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)

    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8


optimizer_names_8bit = [
    "adam8bit",
    "lion8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lion8bit_blockwise",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
    "ademamix8bit_blockwise",
    "ademamix8bit_blockwise_scheduled",
]


@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    torch.set_printoptions(precision=6)

    if gtype == torch.bfloat16 and optim_name not in [
        "adam8bit_blockwise",
        "lion8bit_blockwise",
        "ademamix8bit_blockwise",
    ]:
        pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048
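    # Blockwise 8-bit state keeps one absmax scale per block of `blocksize`
    # values; the same blocksize is passed to F.dequantize_blockwise below.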

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
    elif gtype == torch.bfloat16:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-4, 1e-2
    else:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

    for i in range(50):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        # since Lion can have pretty noisy updates where things lie at the boundary
        # and AdEMAMix can diverge as well, allow up to 0.05% errors.
        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-4))

        dequant_states = []
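        # Dequantize each 8-bit state tensor and compare it against the
        # corresponding full-precision reference state.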
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            # print(bnb_optimizer.state[p2][max_val], name1)
            if "blockwise" in optim_name:
                ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
                ## separately and then stack them. The qmap is shared, but absmax is also stacked.
                if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
                    m1 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val][0],
                        A=bnb_optimizer.state[p2][name2][0],
                        blocksize=blocksize,
                    )
                    m2 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val][1],
                        A=bnb_optimizer.state[p2][name2][1],
                        blocksize=blocksize,
                    )

                    s1 = torch.stack((m1, m2))

                else:
                    s1 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )
            else:
                s1 = F.dequantize(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
            num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
            # assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / (torch.abs(p1) + 1e-9)
        if g.dtype == torch.bfloat16:
            assert err.mean() < 0.00015
            assert relerr.mean() < 0.0020  # 0.0016
        else:
            assert err.mean() < 0.00016  # 0.00012
            assert relerr.mean() < 0.0016  # 0.0012

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                rm_path(path)
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

                if "blockwise" in optim_name:
                    ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
                    ## separately and then stack them. The qmap is shared, but absmax is also stacked.
                    if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
                        s1 = torch.stack(
                            (
                                F.dequantize_blockwise(
                                    code=bnb_optimizer.state[p2][qmap],
                                    absmax=bnb_optimizer.state[p2][max_val][0],
                                    A=bnb_optimizer.state[p2][name2][0],
                                    blocksize=blocksize,
                                ),
                                F.dequantize_blockwise(
                                    code=bnb_optimizer.state[p2][qmap],
                                    absmax=bnb_optimizer.state[p2][max_val][1],
                                    A=bnb_optimizer.state[p2][name2][1],
                                    blocksize=blocksize,
                                ),
                            )
                        )
                    else:
                        s1 = F.dequantize_blockwise(
                            code=bnb_optimizer.state[p2][qmap],
                            absmax=bnb_optimizer.state[p2][max_val],
                            A=bnb_optimizer.state[p2][name2],
                            blocksize=blocksize,
                        )
                else:
                    s1 = F.dequantize(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                    )
                torch.testing.assert_close(s1cpy, s1)

                num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
                assert num_not_close.sum().item() < 20
            # since Lion can have pretty noisy updates where things lie at the boundary
            # and AdEMAMix can also be noisy, allow up to 0.05%.
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-04))

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)

    # print(sum(errors)/len(errors))
    # print(sum(relerrors)/len(relerrors))


@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits"))
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
    )

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
        g2 = g1.clone()
        p2.grad = g2

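        # Clip g1 by hand with the same percentile clipping that adam2
        # (percentile_clipping=5) applies internally, so p1 and p2 should track.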
        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_close(p1, p2)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
                rtol=1e-3,
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            torch.save(adam2.state_dict(), join(path, "opt.pt"))
            del adam2
            adam2 = None
            adam2 = bnb.optim.Adam(
                [p2],
                lr,
                (beta1, beta2),
                eps,
                optim_bits=optim_bits,
                percentile_clipping=5,
            )
            adam2.load_state_dict(torch.load(join(path, "opt.pt")))


optimizer_names_benchmark = [
    "adam8bit_blockwise",
    "paged_adam8bit_blockwise",
    "paged_adamw8bit_blockwise",
    "paged_lion8bit_blockwise",
    "paged_ademamix8bit_blockwise",
]


@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
@pytest.mark.benchmark
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    for i in range(k):
        if i == k // 5:
            # the first k // 5 iterations are burn-in; timing starts after them
            torch.cuda.synchronize()
            t0 = time.time()

        bnb_optimizer.step()

    torch.cuda.synchronize()
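    # Report seconds per parameter update, averaged over the timed
    # (post-burn-in) iterations.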
    s = time.time() - t0
    print("")
    params = (k - k // 5) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9


@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
    layers1 = layers1.to(gtype)
    layers1 = layers1.cuda()

    large_tensor = None
    if mode == "torch":
        optim = str2optimizers[optim_name][0](layers1.parameters())
    else:
        optim = str2optimizers[optim_name][1](layers1.parameters())
        # ~18 GB allocation (4.5e9 float32 values) to create GPU memory pressure for the paged optimizer
        large_tensor = torch.empty((int(4.5e9),), device="cuda")

    torch.cuda.synchronize()
    time.sleep(5)

    num_batches = 5
    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()

    for i in range(num_batches):
        print(i)
        b = batches[i]
        if i == 2:
            torch.cuda.synchronize()
            t0 = time.time()

        out1 = layers1(b)

        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
        loss1.backward()
        optim.step()
    torch.cuda.synchronize()
    print(mode, time.time() - t0)