test_optim.py 17 KB
Newer Older
1
import ctypes
Tim Dettmers's avatar
Tim Dettmers committed
2
3
import os
import shutil
4
import time
Tim Dettmers's avatar
Tim Dettmers committed
5
import uuid
6
7
8
from itertools import product
from os.path import join

Tim Dettmers's avatar
Tim Dettmers committed
9
10
import pytest
import torch
11

Tim Dettmers's avatar
Tim Dettmers committed
12
13
14
import bitsandbytes as bnb
import bitsandbytes.functional as F

15
# import apex
Tim Dettmers's avatar
Tim Dettmers committed
16
17

# Number of optimizer steps taken by the 32-bit comparison test and the benchmark.
k: int = 20
Tim Dettmers's avatar
Tim Dettmers committed
18

19

Tim Dettmers's avatar
Tim Dettmers committed
20
def get_temp_dir():
    """Create a fresh, uniquely named scratch directory and return its path.

    The directory is created atomically by :func:`tempfile.mkdtemp` inside the
    platform's temporary directory, instead of under a hard-coded ``/tmp`` path,
    so the helper also works where ``/tmp`` does not exist. The caller owns the
    directory and should delete it with ``rm_path`` when done.

    Returns:
        str: path of the newly created, empty directory.
    """
    # Local import keeps this block self-contained; the file-level import
    # block is left untouched.
    import tempfile

    return tempfile.mkdtemp(prefix="autoswap_")

25

Tim Dettmers's avatar
Tim Dettmers committed
26
27
28
def rm_path(path):
    """Recursively delete the directory tree rooted at ``path``.

    Errors, such as a missing directory, are not suppressed and propagate to
    the caller.
    """
    shutil.rmtree(path, ignore_errors=False)

Tim Dettmers's avatar
Tim Dettmers committed
29
30
# Names of 8-bit optimizers that are exercised with bfloat16 parameters.
# In this file only membership (`name in str2bf16support`) is checked.
str2bf16support = {"adam8bit_blockwise": True}
31

Tim Dettmers's avatar
Tim Dettmers committed
32
# Optimizer registry: name -> tuple of optimizer factories, each called with a
# list of parameters. For the two-element entries, index 0 builds the reference
# optimizer and index 1 builds the bitsandbytes optimizer under test.
# NOTE(review): the three-element "*_pytorch" entries start with None, so their
# index 1 is the PyTorch optimizer; they look intended only for the benchmark,
# which uses index 1 — confirm before using them in a comparison test.
str2optimizers = {
    "adam_pytorch": (None, torch.optim.Adam, bnb.optim.Adam),
    # "adam_apex": (None, apex.optimizers.FusedAdam, bnb.optim.Adam),
    # "momentum_apex": (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam),
    "momentum_pytorch": (
        None,
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        bnb.optim.Adam,
    ),
    "adam": (torch.optim.Adam, bnb.optim.Adam),
    # "fused_adam": (apex.optimizers.FusedAdam, bnb.optim.Adam),
    "momentum": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD(params, 0.01, 0.9, block_wise=False),
    ),
    "lars": (
        lambda params: bnb.optim.PytorchLARS(params, 0.01, 0.9),
        lambda params: bnb.optim.LARS(params, 0.01, 0.9),
    ),
    "rmsprop": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop(params, 0.01, 0.9, block_wise=False),
    ),
    # 8-bit optimizers with a single (tensor-wide) quantization scale.
    "adam8bit": (
        torch.optim.Adam,
        lambda params: bnb.optim.Adam8bit(params, block_wise=False),
    ),
    "momentum8bit": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD8bit(params, 0.01, 0.9, block_wise=False),
    ),
    "rmsprop8bit": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop8bit(params, 0.01, 0.9, block_wise=False),
    ),
    "lars8bit": (
        lambda params: bnb.optim.PytorchLARS(params, 0.01, 0.9),
        lambda params: bnb.optim.LARS8bit(params, 0.01, 0.9),
    ),
    # 8-bit optimizers with block-wise quantization.
    "adam8bit_blockwise": (
        torch.optim.Adam,
        lambda params: bnb.optim.Adam8bit(params, block_wise=True),
    ),
    "momentum8bit_blockwise": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD8bit(params, 0.01, 0.9, block_wise=True),
    ),
    "rmsprop8bit_blockwise": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop8bit(params, 0.01, 0.9, block_wise=True),
    ),
}
Tim Dettmers's avatar
Tim Dettmers committed
84
85

# Maps optimizer name -> per-state key tuples used to compare optimizer states.
#   32-bit entries: (reference state key, bnb state key)
#   8-bit entries:  (reference state key, bnb quantized state key,
#                    bnb quantization-map key, bnb scale key passed as `absmax`)
str2statenames = {
    "adam": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "momentum": [("momentum_buffer", "state1")],
    "lars": [("momentum_buffer", "state1")],
    "lamb": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "rmsprop": [("square_avg", "state1")],
    "adam8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "lamb8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "adam8bit_blockwise": [
        ("exp_avg", "state1", "qmap1", "absmax1"),
        ("exp_avg_sq", "state2", "qmap2", "absmax2"),
    ],
    "momentum8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
    "momentum8bit_blockwise": [("momentum_buffer", "state1", "qmap1", "absmax1")],
    "lars8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
    "rmsprop8bit": [("square_avg", "state1", "qmap1", "max1")],
    "rmsprop8bit_blockwise": [("square_avg", "state1", "qmap1", "absmax1")],
}
Tim Dettmers's avatar
Tim Dettmers committed
114
115
116
117

# Parameter grid for test_optimizer32bit: every (dim1, dim2, gtype, optimizer)
# combination, each with a readable pytest id.
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    f"dim1_{rows}_dim2_{cols}_gtype_{dtype}_optim_{opt}"
    for rows, cols, dtype, opt in values
]
123
124


Tim Dettmers's avatar
Tim Dettmers committed
125
126
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    """Check a 32-bit bitsandbytes optimizer against its reference for ``k`` steps.

    Both optimizers start from the same parameters and receive identical
    gradients. After every step their states and parameters must agree within
    ``atol``/``rtol``. Every ``k // 5`` steps (excluding step 0), the bnb
    optimizer is saved to disk, rebuilt from the saved ``state_dict`` and
    re-checked.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not exercised")

    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    # The reference optimizer always runs in fp32; bnb updates p2 in ``gtype``.
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    def check_states():
        """Assert every reference state matches the current bnb optimizer's state."""
        for name1, name2 in str2statenames[optim_name]:
            # check_dtype=False mirrors the deprecated assert_allclose semantics.
            torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2],
                atol=atol,
                rtol=rtol,
                check_dtype=False,
            )

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        check_states()
        torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                # Rebuild the optimizer from scratch and restore the saved state.
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # Remove the scratch directory even if saving/loading fails.
                rm_path(path)
            torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
            check_states()

        if gtype == torch.float16:
            # The optimizer states are 32-bit and should stay close, but the
            # 16-bit parameters drift further apart with every update.
            # Re-sync the parameters so the drift does not accumulate.
            p1.data = p1.data.half().float()
            p2.copy_(p1.data)
            torch.testing.assert_close(p1.half(), p2)

        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0

Tim Dettmers's avatar
Tim Dettmers committed
187
188
189
190

# Parameter grid for test_global_config: (dim1, dim2, gtype) with readable ids.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = [f"dim1_{rows}_dim2_{cols}_gtype_{dtype}" for rows, cols, dtype in values]
193
194


Tim Dettmers's avatar
Tim Dettmers committed
195
196
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
    """Check that a per-parameter GlobalOptimManager override takes effect.

    ``p3`` is overridden to ``optim_bits=8``. After every Adam step, its two
    optimizer states must be stored as ``torch.uint8``.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not exercised")

    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    # The override and registration are done on the CPU tensors, before they
    # are moved to the GPU below.
    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(
        p3, "optim_bits", 8
    )
    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
        [p1, p2, p3]
    )
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    for _ in range(50):
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8
Tim Dettmers's avatar
Tim Dettmers committed
239
240
241
242


# Parameter grid for test_optimizer8bit. The test itself filters out bfloat16
# combinations for optimizers not listed in str2bf16support.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [
    "adam8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lars8bit",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    f"dim1_{rows}_dim2_{cols}_gtype_{dtype}_optim_{opt}"
    for rows, cols, dtype, opt in values
]
257
258


Tim Dettmers's avatar
Tim Dettmers committed
259
260
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    """Compare an 8-bit bitsandbytes optimizer with its 32-bit reference.

    Both optimizers receive identical gradients for 100 steps. After each step:

    * the parameters must agree within ``patol``/``prtol``, and the mean
      absolute and relative parameter errors must stay below fixed bounds;
    * the parameters are re-synced, and the reference optimizer's states are
      overwritten with the dequantized 8-bit states so errors do not compound.

    Every 10 steps (excluding step 0), the bnb optimizer is saved to disk and
    rebuilt from the saved ``state_dict``. Its raw quantized states,
    quantization maps and dequantized values must be unchanged, and fewer than
    20 elements per state may differ from the reference beyond ``atol``/``rtol``.
    """
    if gtype == torch.bfloat16 and optim_name not in str2bf16support:
        pytest.skip(f"{optim_name} has no bfloat16 support")
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not exercised")

    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    # The reference optimizer runs in fp32; bnb updates p2 in ``gtype``.
    p1 = p1.float()
    # NOTE(review): assumed to match the block size used by the blockwise
    # optimizers — confirm against bnb.optim.
    blocksize = 2048

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    # atol/rtol: tolerance for optimizer states; patol/prtol: for parameters.
    atol, rtol = 3e-3, 1e-3
    if gtype == torch.bfloat16:
        patol, prtol = 1e-4, 1e-2
    else:
        patol, prtol = 1e-5, 1e-3

    def dequantize_state(name2, qmap, max_val):
        """Return the dequantized value of one 8-bit state of the current bnb optimizer."""
        state = bnb_optimizer.state[p2]
        if "blockwise" in optim_name:
            return F.dequantize_blockwise(
                code=state[qmap],
                absmax=state[max_val],
                A=state[name2],
                blocksize=blocksize,
            )
        return F.dequantize(
            code=state[qmap],
            absmax=state[max_val],
            A=state[name2],
        )

    for i in range(100):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

        dequant_states = [
            dequantize_state(name2, qmap, max_val).clone()
            for _, name2, qmap, max_val in str2statenames[optim_name]
        ]

        err = torch.abs(p1 - p2)
        relerr = err / torch.abs(p1)
        if g.dtype == torch.bfloat16:
            assert err.mean() < 0.00015
            assert relerr.mean() < 0.0015
        else:
            assert err.mean() < 0.0001
            assert relerr.mean() < 0.001

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                try:
                    torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                    # Rebuild the optimizer from scratch and restore the saved state.
                    bnb_optimizer = str2optimizers[optim_name][1]([p2])
                    bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                finally:
                    # Remove the scratch directory even if saving/loading fails.
                    rm_path(path)

                # The round trip must preserve the quantized data exactly.
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

                s1 = dequantize_state(name2, qmap, max_val)
                torch.testing.assert_close(s1cpy, s1)

                not_close = ~torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                )
                assert not_close.sum().item() < 20
            torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

        # The parameters diverge quickly. Keep them close together (and copy the
        # dequantized 8-bit states into the reference) so each step measures
        # only the error introduced by that step.
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, *_), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)
Tim Dettmers's avatar
Tim Dettmers committed
379
380
381
382
383
384


# Parameter grid for test_adam_percentile_clipping: shape, dtype and the
# number of optimizer-state bits, with readable pytest ids.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
    f"dim1_{rows}_dim2_{cols}_gtype_{dtype}_optim_bits_{bits}"
    for rows, cols, dtype, bits in values
]
390
391


Tim Dettmers's avatar
Tim Dettmers committed
392
393
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    """Compare Adam with built-in percentile clipping against manual clipping.

    ``adam1`` receives gradients that were scaled by hand with
    ``F.percentile_clipping``. ``adam2`` receives the raw gradients and is
    configured with ``percentile_clipping=5``. Parameters and states must stay
    close. Every 10 steps (excluding step 0), ``adam2`` is rebuilt from a
    ``state_dict`` saved to disk.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not exercised")

    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()

    def make_clipped_adam():
        """Build the optimizer under test, with built-in percentile clipping, over p2."""
        return bnb.optim.Adam(
            [p2],
            lr,
            (beta1, beta2),
            eps,
            optim_bits=optim_bits,
            percentile_clipping=5,
        )

    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = make_clipped_adam()

    gnorm_vec = torch.zeros(100).cuda()

    for i in range(50):
        step = i + 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
            0.01 * i
        )
        g2 = g1.clone()
        p2.grad = g2

        _, _, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), so the states
        # can differ slightly.
        if optim_bits == 32:
            # rtol=1e-4 / atol=1e-5 are the float32 defaults of the deprecated
            # torch.testing.assert_allclose that this check originally used.
            torch.testing.assert_close(p1, p2, atol=1e-5, rtol=1e-4)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
                check_dtype=False,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
                check_dtype=False,
            )
        elif optim_bits == 8:
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
                check_dtype=False,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
                rtol=1e-3,
                check_dtype=False,
            )
            # Re-sync adam1's states so that quantization differences do not accumulate.
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])

        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(adam2.state_dict(), join(path, "opt.pt"))
                adam2 = make_clipped_adam()
                adam2.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # Previously this scratch directory was never deleted.
                rm_path(path)
Tim Dettmers's avatar
Tim Dettmers committed
478
479
480
481
482


# Parameter grid for test_benchmark_blockwise. Other optimizer selections
# that have been used for benchmarking:
#   ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
#   ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
#   ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
#   ['lamb_apex', 'lamb8bit']
#   ['lars_apex', 'lars8bit']
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    f"dim1_{rows}_dim2_{cols}_gtype_{dtype}_optim_{opt}"
    for rows, cols, dtype, opt in values
]
493
494


Tim Dettmers's avatar
Tim Dettmers committed
495
496
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    """Time ``k`` optimizer steps and print the mean seconds per parameter update.

    The first ``k // 5`` steps are warm-up and are excluded from the timing.
    The result is printed only; nothing is asserted.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not exercised")

    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    # One fixed gradient is reused for every step.
    p1.grad = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01

    warmup = k // 5
    for i in range(k):
        if i == warmup:
            # Warm-up finished: drain queued GPU work, then start the clock.
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        bnb_optimizer.step()

    # Wait for all timed GPU work to finish before stopping the clock.
    torch.cuda.synchronize()
    s = time.perf_counter() - t0
    print("")
    params = (k - warmup) * dim1 * dim2
    print(optim_name, gtype, s / params)