test_optim.py 22 KB
Newer Older
Tim Dettmers's avatar
Tim Dettmers committed
1
import os
Aarni Koskela's avatar
Aarni Koskela committed
2
from os.path import join
Tim Dettmers's avatar
Tim Dettmers committed
3
import shutil
4
import sys
5
import time
Tim Dettmers's avatar
Tim Dettmers committed
6
import uuid
7

8
from lion_pytorch import Lion
Aarni Koskela's avatar
Aarni Koskela committed
9
import pytest
Tim Dettmers's avatar
Tim Dettmers committed
10
import torch
11

Tim Dettmers's avatar
Tim Dettmers committed
12
13
import bitsandbytes as bnb
import bitsandbytes.functional as F
Egor Krivov's avatar
Egor Krivov committed
14
15
from bitsandbytes.utils import sync_gpu
from tests.helpers import describe_dtype, get_available_devices, id_formatter
Tim Dettmers's avatar
Tim Dettmers committed
16

17
# import apex
Tim Dettmers's avatar
Tim Dettmers committed
18
19

k = 20
Tim Dettmers's avatar
Tim Dettmers committed
20

Ruff's avatar
Ruff committed
21

22
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
Tim Dettmers's avatar
Tim Dettmers committed
23
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
24
25
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
26
        print(f"Too many values not close: assert {error_count} < {max_error_count}")
Tim Dettmers's avatar
Tim Dettmers committed
27
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
28

29

Tim Dettmers's avatar
Tim Dettmers committed
30
def get_temp_dir():
Aarni Koskela's avatar
Aarni Koskela committed
31
    path = f"/tmp/autoswap/{uuid.uuid4()}"
Tim Dettmers's avatar
Tim Dettmers committed
32
33
34
    os.makedirs(path, exist_ok=True)
    return path

35

Tim Dettmers's avatar
Tim Dettmers committed
36
37
38
def rm_path(path):
    shutil.rmtree(path)

Ruff's avatar
Ruff committed
39

Tim Dettmers's avatar
Tim Dettmers committed
40
str2optimizers = {}
41
42

## TODO: maybe remove these three.
43
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
44
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
45
46
47
48
49
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)
50

51
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
52
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
Tim Dettmers's avatar
Tim Dettmers committed
53
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["paged_adamw8bit_blockwise"] = (
    torch.optim.AdamW,
    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)

str2optimizers["ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.AdEMAMix)
str2optimizers["ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx),
)
str2optimizers["paged_ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.PagedAdEMAMix)
str2optimizers["paged_ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx),
)
str2optimizers["ademamix_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
    lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
78
79
80
81
str2optimizers["paged_ademamix_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
    lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
82
83
84
85
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
86
87
88
89
str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
90

91
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
Tim Dettmers's avatar
Tim Dettmers committed
92
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
93
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
94
95
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))

96
97
98
99
100
101
102
103
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
104
105
106
107
108

str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
109
110
111
112
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)
Tim Dettmers's avatar
Tim Dettmers committed
113
114

str2statenames = {}
115
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
Tim Dettmers's avatar
Tim Dettmers committed
116
117
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
118
str2statenames["lion"] = [("exp_avg", "state1")]
Tim Dettmers's avatar
Tim Dettmers committed
119
str2statenames["paged_lion"] = [("exp_avg", "state1")]
120
121
122
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
123

Ruff's avatar
Ruff committed
124
125
126
127
128
129
130
131
132
133
134
135
str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
136

Tim Dettmers's avatar
Tim Dettmers committed
137
138
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
139
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
Tim Dettmers's avatar
Tim Dettmers committed
140
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
Tim Dettmers's avatar
Tim Dettmers committed
141

142
str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
143
str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_ademamix8bit_blockwise"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]

optimizer_names_32bit = [
    "adam",
    "paged_adamw",
    "paged_adam",
    "momentum",
    "rmsprop",
    "lion",
    "paged_lion",
    "ademamix",
    "ademamix_scheduled",
    "paged_ademamix",
164
    "paged_ademamix_scheduled",
165
]
Aarni Koskela's avatar
Aarni Koskela committed
166
167
168
169
170
171


@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
Egor Krivov's avatar
Egor Krivov committed
172
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
Matthew Douglas's avatar
Matthew Douglas committed
173
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
Egor Krivov's avatar
Egor Krivov committed
174
def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
175
176
177
    if optim_name.startswith("paged_") and sys.platform == "win32":
        pytest.skip("Paged optimizers can have issues on Windows.")

Ruff's avatar
Ruff committed
178
    if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
Aarni Koskela's avatar
Aarni Koskela committed
179
        pytest.skip()
180
181
    if dim1 == 1 and dim2 == 1:
        return
Egor Krivov's avatar
Egor Krivov committed
182
    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
183
184
185
186
187
188
189
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
Tim Dettmers's avatar
Tim Dettmers committed
190
        atol, rtol = 1e-6, 1e-5
191
192
    elif gtype == torch.bfloat16:
        atol, rtol = 1e-3, 1e-2
Tim Dettmers's avatar
Tim Dettmers committed
193
194
195
    else:
        atol, rtol = 1e-4, 1e-3

Tim Dettmers's avatar
Tim Dettmers committed
196
    for i in range(k):
Egor Krivov's avatar
Egor Krivov committed
197
        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
Tim Dettmers's avatar
Tim Dettmers committed
198
199
200
201
202
203
204
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
Tim Dettmers's avatar
Tim Dettmers committed
205
            torch.testing.assert_close(
206
                torch_optimizer.state[p1][name1],
Egor Krivov's avatar
Egor Krivov committed
207
                bnb_optimizer.state[p2][name2].to(device),
208
209
210
                atol=atol,
                rtol=rtol,
            )
Tim Dettmers's avatar
Tim Dettmers committed
211

212
        # since Lion can have pretty noisy updates where things lie at the boundary
Egor Krivov's avatar
Egor Krivov committed
213
214
        # allow up to 15 errors for Lion
        assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=15)
Tim Dettmers's avatar
Tim Dettmers committed
215

216
        if i % (k // 5) == 0 and i > 0:
Tim Dettmers's avatar
Tim Dettmers committed
217
            path = get_temp_dir()
218
            torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
Tim Dettmers's avatar
Tim Dettmers committed
219
220
221
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
222
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
Tim Dettmers's avatar
Tim Dettmers committed
223
            rm_path(path)
224
225
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
226
            assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
Tim Dettmers's avatar
Tim Dettmers committed
227
            for name1, name2 in str2statenames[optim_name]:
228
229
                # since Lion can have pretty noisy updates where things lie at the boundary
                # allow up to 10 errors for Lion
Ruff's avatar
Ruff committed
230
231
232
233
234
235
236
                assert_most_approx_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                    max_error_count=10,
                )
Tim Dettmers's avatar
Tim Dettmers committed
237

238
        if gtype != torch.float32:
Tim Dettmers's avatar
Tim Dettmers committed
239
            # the adam buffers should also be close because they are 32-bit
240
            # but the parameters can diverge because they are 16-bit
Tim Dettmers's avatar
Tim Dettmers committed
241
242
            # the difference grow larger and larger with each update
            # --> copy the state to keep weights close
243
            p1.data = p1.data.to(p2.dtype).float()
Tim Dettmers's avatar
Tim Dettmers committed
244
            p2.copy_(p1.data)
Tim Dettmers's avatar
Tim Dettmers committed
245
            torch.testing.assert_close(p1.to(p2.dtype), p2)
246
247
248
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0

Tim Dettmers's avatar
Tim Dettmers committed
249

Aarni Koskela's avatar
Aarni Koskela committed
250
251
252
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
Egor Krivov's avatar
Egor Krivov committed
253
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
Matthew Douglas's avatar
Matthew Douglas committed
254
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
Egor Krivov's avatar
Egor Krivov committed
255
def test_global_config(dim1, dim2, gtype, device):
256
257
258
259
260
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
261
262
263
264
265
266
267
    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
Ruff's avatar
Ruff committed
268
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
Tim Dettmers's avatar
Tim Dettmers committed
269

Ruff's avatar
Ruff committed
270
    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
Egor Krivov's avatar
Egor Krivov committed
271
272
273
    p1 = p1.to(device)
    p2 = p2.to(device)
    p3 = p3.to(device)
Tim Dettmers's avatar
Tim Dettmers committed
274
275
276
277
278
279
280
281
282

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
Egor Krivov's avatar
Egor Krivov committed
283
284
285
        g1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
Tim Dettmers's avatar
Tim Dettmers committed
286
287
288
289
290
291
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

292
293
        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8
Tim Dettmers's avatar
Tim Dettmers committed
294
295


Aarni Koskela's avatar
Aarni Koskela committed
296
optimizer_names_8bit = [
297
    "adam8bit_blockwise",
298
    "lion8bit_blockwise",
299
300
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
301
302
    "ademamix8bit_blockwise",
    "ademamix8bit_blockwise_scheduled",
303
304
305
]


Aarni Koskela's avatar
Aarni Koskela committed
306
307
308
309
@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
Egor Krivov's avatar
Egor Krivov committed
310
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
Matthew Douglas's avatar
Matthew Douglas committed
311
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
Egor Krivov's avatar
Egor Krivov committed
312
def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
313
314
    torch.set_printoptions(precision=6)

315
316
    if dim1 == 1 and dim2 == 1:
        return
Matthew Douglas's avatar
Matthew Douglas committed
317

Egor Krivov's avatar
Egor Krivov committed
318
    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
319
320
    p2 = p1.clone()
    p1 = p1.float()
321
    blocksize = 256
Tim Dettmers's avatar
Tim Dettmers committed
322
323
324
325
326
327
328

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
Tim Dettmers's avatar
Tim Dettmers committed
329
330
331
    elif gtype == torch.bfloat16:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-4, 1e-2
Tim Dettmers's avatar
Tim Dettmers committed
332
333
334
335
336
337
338
    else:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

339
    for i in range(50):
Egor Krivov's avatar
Egor Krivov committed
340
        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
Tim Dettmers's avatar
Tim Dettmers committed
341
342
343
344
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        torch_optimizer.step()
Egor Krivov's avatar
Egor Krivov committed
345
        bnb_optimizer.step()
Tim Dettmers's avatar
Tim Dettmers committed
346

347
        # since Lion can have pretty noisy updates where things lie at the boundary
Egor Krivov's avatar
Egor Krivov committed
348
        # assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
Tim Dettmers's avatar
Tim Dettmers committed
349
350
351

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
Matthew Douglas's avatar
Matthew Douglas committed
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
            ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
            ## separately and then stack them. The qmap is shared, but absmax is also stacked.
            if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
                m1 = F.dequantize_blockwise(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val][0],
                    A=bnb_optimizer.state[p2][name2][0],
                    blocksize=blocksize,
                )
                m2 = F.dequantize_blockwise(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val][1],
                    A=bnb_optimizer.state[p2][name2][1],
                    blocksize=blocksize,
                )
367

Matthew Douglas's avatar
Matthew Douglas committed
368
                s1 = torch.stack((m1, m2))
Tim Dettmers's avatar
Tim Dettmers committed
369
            else:
Matthew Douglas's avatar
Matthew Douglas committed
370
                s1 = F.dequantize_blockwise(
371
372
373
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
Matthew Douglas's avatar
Matthew Douglas committed
374
                    blocksize=blocksize,
375
                )
Matthew Douglas's avatar
Matthew Douglas committed
376

Ruff's avatar
Ruff committed
377
            num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
Egor Krivov's avatar
Egor Krivov committed
378
            assert num_not_close.sum().item() < 20
Tim Dettmers's avatar
Tim Dettmers committed
379
380
            dequant_states.append(s1.clone())

381
        err = torch.abs(p1 - p2)
Ruff's avatar
Ruff committed
382
        relerr = err / (torch.abs(p1) + 1e-9)
Tim Dettmers's avatar
Tim Dettmers committed
383
        if g.dtype == torch.bfloat16:
384
385
            assert err.mean() <= 0.00017
            assert relerr.mean() <= 0.0016
Tim Dettmers's avatar
Tim Dettmers committed
386
        else:
387
388
            assert err.mean() < 0.00006
            assert relerr.mean() < 0.0006
Tim Dettmers's avatar
Tim Dettmers committed
389
390
391
392
393

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
Ruff's avatar
Ruff committed
394
            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
Tim Dettmers's avatar
Tim Dettmers committed
395
396
397
398
399
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
400
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
Tim Dettmers's avatar
Tim Dettmers committed
401
402
403
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
404
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
Tim Dettmers's avatar
Tim Dettmers committed
405
                rm_path(path)
Tim Dettmers's avatar
Tim Dettmers committed
406
407
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
Tim Dettmers's avatar
Tim Dettmers committed
408

Matthew Douglas's avatar
Matthew Douglas committed
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
                ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
                ## separately and then stack them. The qmap is shared, but absmax is also stacked.
                if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
                    s1 = torch.stack(
                        (
                            F.dequantize_blockwise(
                                code=bnb_optimizer.state[p2][qmap],
                                absmax=bnb_optimizer.state[p2][max_val][0],
                                A=bnb_optimizer.state[p2][name2][0],
                                blocksize=blocksize,
                            ),
                            F.dequantize_blockwise(
                                code=bnb_optimizer.state[p2][qmap],
                                absmax=bnb_optimizer.state[p2][max_val][1],
                                A=bnb_optimizer.state[p2][name2][1],
                                blocksize=blocksize,
                            ),
426
                        )
Matthew Douglas's avatar
Matthew Douglas committed
427
                    )
Tim Dettmers's avatar
Tim Dettmers committed
428
                else:
Matthew Douglas's avatar
Matthew Douglas committed
429
                    s1 = F.dequantize_blockwise(
430
431
432
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
Matthew Douglas's avatar
Matthew Douglas committed
433
                        blocksize=blocksize,
434
                    )
Matthew Douglas's avatar
Matthew Douglas committed
435

Tim Dettmers's avatar
Tim Dettmers committed
436
                torch.testing.assert_close(s1cpy, s1)
Tim Dettmers's avatar
Tim Dettmers committed
437

Ruff's avatar
Ruff committed
438
                num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
Tim Dettmers's avatar
Tim Dettmers committed
439
                assert num_not_close.sum().item() < 20
440
441
442

            # Lion can have pretty noisy updates where things lie at the boundary
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
Tim Dettmers's avatar
Tim Dettmers committed
443
444
445
446
447

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
Tim Dettmers's avatar
Tim Dettmers committed
448
        torch.testing.assert_close(p1.to(gtype), p2)
Tim Dettmers's avatar
Tim Dettmers committed
449
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
Tim Dettmers's avatar
Tim Dettmers committed
450
451
452
            torch_optimizer.state[p1][name1].copy_(s.data)


Aarni Koskela's avatar
Aarni Koskela committed
453
454
455
456
@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits"))
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
457
458
@pytest.mark.deprecated
def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits):
459
460
461
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
462
463
464
465
466
467
468
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
469
    adam2 = bnb.optim.Adam(
470
471
472
473
474
475
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
476
    )
Tim Dettmers's avatar
Tim Dettmers committed
477
478
479
480
481
482

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
Ruff's avatar
Ruff committed
483
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
Tim Dettmers's avatar
Tim Dettmers committed
484
485
486
        g2 = g1.clone()
        p2.grad = g2

Ruff's avatar
Ruff committed
487
        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
488
        g1 = (g1.float() * gnorm_scale).to(gtype)
Tim Dettmers's avatar
Tim Dettmers committed
489
490
491
492
493
494
495
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
Tim Dettmers's avatar
Tim Dettmers committed
496
497
            torch.testing.assert_close(p1, p2)
            torch.testing.assert_close(
498
499
500
501
502
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
Tim Dettmers's avatar
Tim Dettmers committed
503
            torch.testing.assert_close(
504
505
506
507
508
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
Tim Dettmers's avatar
Tim Dettmers committed
509
        elif optim_bits == 8:
Tim Dettmers's avatar
Tim Dettmers committed
510
511
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_close(
512
513
514
515
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
516
            )
Tim Dettmers's avatar
Tim Dettmers committed
517
            torch.testing.assert_close(
518
519
520
521
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
                rtol=1e-3,
522
523
524
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
Tim Dettmers's avatar
Tim Dettmers committed
525
526
        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
527
            torch.save(adam2.state_dict(), join(path, "opt.pt"))
Tim Dettmers's avatar
Tim Dettmers committed
528
529
            del adam2
            adam2 = None
530
531
532
533
534
535
536
537
538
            adam2 = bnb.optim.Adam(
                [p2],
                lr,
                (beta1, beta2),
                eps,
                optim_bits=optim_bits,
                percentile_clipping=5,
            )
            adam2.load_state_dict(torch.load(join(path, "opt.pt")))
Tim Dettmers's avatar
Tim Dettmers committed
539
540


Aarni Koskela's avatar
Aarni Koskela committed
541
542
543
optimizer_names_benchmark = [
    "adam8bit_blockwise",
    "paged_adam8bit_blockwise",
544
545
546
547
548
    "ademamix8bit_blockwise",
    "paged_ademamix8bit_blockwise",
    "ademamix8bit_blockwise_scheduled",
    "paged_ademamix8bit_blockwise_scheduled",
    "lion8bit_blockwise",
Aarni Koskela's avatar
Aarni Koskela committed
549
    "paged_lion8bit_blockwise",
550
    "paged_ademamix8bit_blockwise",
551
]
552
553


Aarni Koskela's avatar
Aarni Koskela committed
554
555
@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
556
@pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
Aarni Koskela's avatar
Aarni Koskela committed
557
558
@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
@pytest.mark.benchmark
Egor Krivov's avatar
Egor Krivov committed
559
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device):
560
561
    if dim1 == 1 and dim2 == 1:
        return
Egor Krivov's avatar
Egor Krivov committed
562
    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
Tim Dettmers's avatar
Tim Dettmers committed
563
564
565

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

Egor Krivov's avatar
Egor Krivov committed
566
    g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
Tim Dettmers's avatar
Tim Dettmers committed
567
    p1.grad = g
568
569
570
    total_steps = 500
    for i in range(total_steps):
        if i == total_steps // 5:
Tim Dettmers's avatar
Tim Dettmers committed
571
            # 100 iterations for burn-in
Egor Krivov's avatar
Egor Krivov committed
572
            sync_gpu(p1)
Tim Dettmers's avatar
Tim Dettmers committed
573
574
575
576
            t0 = time.time()

        bnb_optimizer.step()

Egor Krivov's avatar
Egor Krivov committed
577
    sync_gpu(p1)
578
579
    s = time.time() - t0
    print("")
580
581
    params = (total_steps - total_steps // 5) * dim1 * dim2
    print(optim_name, gtype, s, params, s / params)
582
    # assert s < 3.9