import ctypes
import os
import shutil
import tempfile
import time
import uuid
from itertools import product
from os.path import join

import pytest
import torch
from lion_pytorch import Lion

import bitsandbytes as bnb
import bitsandbytes.functional as F

# import apex

# Number of optimizer steps run by test_optimizer32bit and test_benchmark_blockwise.
k = 20

def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
    """Assert that all but at most ``max_error_count`` elements of ``a`` and ``b`` are close.

    Elements are compared with ``torch.isclose``, i.e.
    ``|a - b| <= atol + rtol * |b|``. A few mismatches can be tolerated
    because some optimizers (e.g. Lion) produce updates that lie on a decision
    boundary and may legitimately differ between implementations.

    Note the positional order is ``(rtol, atol)``; callers should pass the
    tolerances by keyword to avoid swapping them.

    Args:
        a: Tensor under test.
        b: Reference tensor, comparable with ``a`` via ``torch.isclose``.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        max_error_count: Largest number of non-close elements still accepted.

    Raises:
        AssertionError: If more than ``max_error_count`` elements are not close.
    """
    not_close = ~torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = int(not_close.sum().item())
    if error_count > max_error_count:
        raise AssertionError(
            f"Too many values not close: {error_count} > {max_error_count} "
            f"(rtol={rtol}, atol={atol})"
        )

def get_temp_dir():
    """Create and return a fresh, uniquely named scratch directory.

    The directory is created under ``<system temp dir>/autoswap/`` (honouring
    ``TMPDIR`` and friends via ``tempfile.gettempdir``) and is used by the
    tests to round-trip optimizer state dicts through ``torch.save`` /
    ``torch.load``. Callers are responsible for removing it (``rm_path``).

    Returns:
        The absolute path of the newly created directory.
    """
    path = os.path.join(tempfile.gettempdir(), "autoswap", str(uuid.uuid4()))
    os.makedirs(path, exist_ok=True)
    return path

def rm_path(path):
    """Recursively delete the directory ``path`` and everything inside it."""
    shutil.rmtree(path)

# Maps an optimizer name to a tuple of optimizer factories; each factory is
# called with a list of parameters. For the plain entries the tuple is
# (reference optimizer, bitsandbytes optimizer under test). Entries whose name
# ends in "_pytorch" carry a leading None, so index [1] is the PyTorch/reference
# optimizer and index [2] a bitsandbytes optimizer.
# NOTE(review): "momentum_pytorch" pairs SGD with bnb.optim.Adam — confirm intended.
str2optimizers = {
    "adam_pytorch": (None, torch.optim.Adam, bnb.optim.Adam),
    # "adam_apex": (None, apex.optimizers.FusedAdam, bnb.optim.Adam),
    # "momentum_apex": (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam),
    "lion_pytorch": (None, Lion, bnb.optim.Lion),
    "momentum_pytorch": (
        None,
        lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
        bnb.optim.Adam,
    ),
    # 32-bit optimizers.
    "adam": (torch.optim.Adam, bnb.optim.Adam),
    # "fused_adam": (apex.optimizers.FusedAdam, bnb.optim.Adam),
    "lion": (Lion, bnb.optim.Lion),
    "momentum": (
        lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
    ),
    "lars": (
        lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
    ),
    "rmsprop": (
        lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
    ),
    # 8-bit optimizers with a single (tensor-wide) quantization scale.
    "adam8bit": (
        torch.optim.Adam,
        lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
    ),
    "lion8bit": (
        Lion,
        lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
    ),
    "momentum8bit": (
        lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
    ),
    "rmsprop8bit": (
        lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
    ),
    "lars8bit": (
        lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
    ),
    # 8-bit optimizers with block-wise quantization.
    "adam8bit_blockwise": (
        torch.optim.Adam,
        lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
    ),
    "lion8bit_blockwise": (
        Lion,
        lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
    ),
    "momentum8bit_blockwise": (
        lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
    ),
    "rmsprop8bit_blockwise": (
        lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
        lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
    ),
}

# Maps an optimizer name to the state buffers compared between the reference
# optimizer and the bitsandbytes optimizer. Each tuple is
#   (reference state key, bnb state key)                      for 32-bit entries,
#   (reference state key, bnb state key, bnb qmap key, bnb scale key)
# for 8-bit entries, where the qmap/scale keys name the quantization code and
# the max/absmax tensor needed to dequantize the bnb state.
str2statenames = {
    "adam": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "lion": [("exp_avg", "state1")],
    "momentum": [("momentum_buffer", "state1")],
    "lars": [("momentum_buffer", "state1")],
    "lamb": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "rmsprop": [("square_avg", "state1")],
    "adam8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "lion8bit": [("exp_avg", "state1", "qmap1", "max1")],
    "lamb8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "adam8bit_blockwise": [
        ("exp_avg", "state1", "qmap1", "absmax1"),
        ("exp_avg_sq", "state2", "qmap2", "absmax2"),
    ],
    "lion8bit_blockwise": [("exp_avg", "state1", "qmap1", "absmax1")],
    "momentum8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
    "momentum8bit_blockwise": [("momentum_buffer", "state1", "qmap1", "absmax1")],
    "lars8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
    "rmsprop8bit": [("square_avg", "state1", "qmap1", "max1")],
    "rmsprop8bit_blockwise": [("square_avg", "state1", "qmap1", "absmax1")],
}

# Parameter grid for test_optimizer32bit.
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    """Compare a 32-bit bitsandbytes optimizer against its reference optimizer.

    Both optimizers take ``k`` steps on identical gradients; after every step
    their states and parameters must stay close. Every ``k // 5`` steps the
    bnb optimizer is round-tripped through ``state_dict``/``load_state_dict``
    to exercise checkpointing.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2],
                atol=atol,
                rtol=rtol,
                check_dtype=False,
            )

        # Lion updates can lie exactly on a sign boundary, so allow up to 10
        # mismatching elements. Tolerances go by keyword: the helper's
        # positional order is (rtol, atol).
        assert_most_approx_close(
            p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10
        )

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                rm_path(path)

            # Same Lion boundary allowance as above.
            assert_most_approx_close(
                p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10
            )
            for name1, name2 in str2statenames[optim_name]:
                assert_most_approx_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                    max_error_count=10,
                )

        if gtype == torch.float16:
            # The 32-bit optimizer states should stay close, but fp16
            # parameters drift from their fp32 reference with every update.
            # Re-sync both copies from the fp16-rounded value to keep them
            # comparable.
            p1.data = p1.data.half().float()
            p2.copy_(p1.data)
            torch.testing.assert_close(p1.half(), p2)

        if optim_name in ("lars", "lamb"):
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0

# Parameter grid for test_global_config.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
    """Check that a GlobalOptimManager per-parameter override is honoured.

    ``p3`` is overridden to ``optim_bits=8`` before the optimizer is built.
    After every step its Adam states must be ``uint8`` while the
    non-overridden ``p1`` keeps 32-bit state.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    manager = bnb.optim.GlobalOptimManager.get_instance()
    manager.initialize()
    manager.override_config(p3, "optim_bits", 8)
    # Parameters are registered while still on the CPU and only moved to the
    # GPU afterwards; the override is expected to survive that move.
    manager.register_parameters([p1, p2, p3])
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    for _ in range(50):
        p1.grad = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p2.grad = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p3.grad = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001

        adam2.step()

        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8
        # The override must not leak onto parameters that were not overridden.
        assert adam2.state[p1]["state1"].dtype == torch.float32


# Parameter grid for test_optimizer8bit.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = [
    "adam8bit",
    "lion8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lion8bit_blockwise",
    "lars8bit",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    """Compare an 8-bit bitsandbytes optimizer against its 32-bit reference.

    Each of 50 steps applies the same gradient to both optimizers, then checks
    that the dequantized 8-bit states and the parameters stay close. Every 10
    steps the bnb optimizer is round-tripped through
    ``state_dict``/``load_state_dict``. The raw quantized state, the
    quantization map and the dequantized state must survive unchanged. After
    each step the parameters and reference states are re-synced, so errors do
    not accumulate across steps.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    # Tolerances for the dequantized optimizer states (identical for every gtype).
    atol, rtol = 3e-3, 1e-3
    # Tolerances for the parameters. NOTE(review): these used to be written as
    # `patol, prtol = 1e-5, 1e-3` but were passed positionally into the
    # helper's (rtol, atol) slots, so the check really enforced atol=1e-3,
    # rtol=1e-5. Those effective values are kept and now passed by keyword;
    # tightening to atol=1e-5 / rtol=1e-3 needs re-validation on a GPU.
    patol, prtol = 1e-3, 1e-5

    def dequantize_state(state, name2, qmap, max_val):
        """Dequantize the 8-bit buffer ``state[name2]`` with its code and scale."""
        if "blockwise" in optim_name:
            return F.dequantize_blockwise(
                code=state[qmap],
                absmax=state[max_val],
                A=state[name2],
                blocksize=blocksize,
            )
        return F.dequantize(
            code=state[qmap],
            absmax=state[max_val],
            A=state[name2],
        )

    def count_not_close(reference, dequantized):
        """Count elements of ``dequantized`` outside (atol, rtol) of ``reference``."""
        return (~torch.isclose(reference, dequantized, atol=atol, rtol=rtol)).sum().item()

    for i in range(50):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        # Lion updates can lie exactly on a sign boundary: allow up to 5 mismatches.
        assert_most_approx_close(
            p1, p2.float(), atol=patol, rtol=prtol, max_error_count=5
        )

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            s1 = dequantize_state(bnb_optimizer.state[p2], name2, qmap, max_val)
            assert count_not_close(torch_optimizer.state[p1][name1], s1) < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / (torch.abs(p1) + 1e-9)
        assert err.mean() < 0.0001
        assert relerr.mean() < 0.001

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                try:
                    torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                    del bnb_optimizer
                    bnb_optimizer = str2optimizers[optim_name][1]([p2])
                    bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                finally:
                    rm_path(path)

                # The reload must reproduce the raw 8-bit buffer bit-exactly.
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                # Float tolerances below are the float32 defaults of the
                # deprecated torch.testing.assert_allclose used previously.
                torch.testing.assert_close(
                    qmap1, bnb_optimizer.state[p2][qmap], rtol=1e-4, atol=1e-5
                )

                s1 = dequantize_state(bnb_optimizer.state[p2], name2, qmap, max_val)
                torch.testing.assert_close(s1cpy, s1, rtol=1e-4, atol=1e-5)

                assert count_not_close(torch_optimizer.state[p1][name1], s1) < 20

            # Same Lion boundary allowance as above.
            assert_most_approx_close(
                p1, p2.float(), atol=patol, rtol=prtol, max_error_count=5
            )

        # The parameters diverge quickly. Re-sync both parameter copies, and the
        # reference states from the dequantized 8-bit states, so each step
        # measures the error of a single update only.
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, _, _, _), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)


# Parameter grid for test_adam_percentile_clipping.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    """Check bnb Adam's built-in percentile clipping against manual clipping.

    ``adam1`` receives gradients pre-scaled by ``F.percentile_clipping``.
    ``adam2`` is built with ``percentile_clipping=5`` and receives the raw
    gradients. Parameters and states must stay close. Every 10 steps ``adam2``
    is round-tripped through its state dict.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)

    def make_clipping_adam():
        """Build the optimizer under test (internal percentile clipping on p2)."""
        return bnb.optim.Adam(
            [p2],
            lr,
            (beta1, beta2),
            eps,
            optim_bits=optim_bits,
            percentile_clipping=5,
        )

    adam2 = make_clipping_adam()

    gnorm_vec = torch.zeros(100).cuda()

    for i in range(50):
        step = i + 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
        p2.grad = g1.clone()

        # Only the gradient scale is needed to clip adam1's gradient manually.
        _, _, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
        p1.grad = (g1.float() * gnorm_scale).to(gtype)

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), so there can be
        # slight differences in state.
        if optim_bits == 32:
            # rtol/atol are the float32 defaults of the deprecated
            # torch.testing.assert_allclose that this check used previously.
            torch.testing.assert_close(p1, p2, rtol=1e-4, atol=1e-5)
            for name in ("state1", "state2"):
                torch.testing.assert_close(
                    adam1.state[p1][name],
                    adam2.state[p2][name],
                    atol=5e-5,
                    rtol=1e-4,
                )
        elif optim_bits == 8:
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            for name in ("state1", "state2"):
                torch.testing.assert_close(
                    adam1.state[p1][name],
                    adam2.state[p2][name],
                    atol=2,
                    rtol=1e-3,
                )
            # Re-sync the quantized states so differences do not accumulate.
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])

        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(adam2.state_dict(), join(path, "opt.pt"))
                del adam2
                adam2 = make_clipping_adam()
                adam2.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # Previously the checkpoint directory was never removed.
                rm_path(path)


# Parameter grid for test_benchmark_blockwise.
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    """Time ``k`` optimizer steps and print the seconds spent per parameter.

    The first ``k // 5`` steps are a warm-up and are excluded from the timing.
    Nothing is asserted; this is a benchmark.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    warmup_steps = k // 5
    for i in range(k):
        if i == warmup_steps:
            # Start the clock only after the warm-up steps have completed.
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        bnb_optimizer.step()

    torch.cuda.synchronize()
    s = time.perf_counter() - t0
    print("")
    params = (k - warmup_steps) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9