# Tests for the bitsandbytes optimizers (32-bit, 8-bit, block-wise and paged).
import ctypes
import os
import shutil
import tempfile
import time
import uuid
from itertools import product
from os.path import join

import pytest
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

# import apex

# Number of optimizer steps run by test_optimizer32bit and
# test_benchmark_blockwise.
k = 20


def get_temp_dir():
    """Create and return a fresh, uniquely named temporary directory.

    The directory is created in the platform's temporary location (which
    honours ``TMPDIR``/``TEMP``/``TMP``) instead of a hard-coded ``/tmp``
    path, so it works on every platform and leaves no shared parent
    directory behind. The caller owns the directory and should remove it,
    e.g. with ``rm_path``.

    Returns:
        str: path of the newly created, empty directory.
    """
    return tempfile.mkdtemp(prefix="autoswap_")


def rm_path(path):
    """Recursively delete the directory *path* and everything inside it.

    Errors from ``shutil.rmtree`` (for example, a missing path) are not
    suppressed; they propagate to the caller as ``OSError``.
    """
    shutil.rmtree(path)


# Optimizer names whose 8-bit test also runs with bfloat16 inputs.
# test_optimizer8bit only checks membership, not the stored value.
str2bf16support = {"adam8bit_blockwise": True}

# Maps a test-case name to a tuple of optimizer factories. Each factory is
# called with a list of parameters, e.g. ``str2optimizers[name][1]([p])``.
# Two-element entries are (torch reference, bitsandbytes optimizer) and are
# compared step by step in the 32-bit and 8-bit tests. The "*_pytorch"
# entries carry a leading ``None`` slot; NOTE(review): they are not part of
# any parameter grid visible in this file.
# SGD/RMSprop factories pass lr=0.01 and a second positional hyper-parameter
# of 0.9 to both implementations so they are configured identically.
str2optimizers = {
    "adam_pytorch": (None, torch.optim.Adam, bnb.optim.Adam),
    # "adam_apex": (None, apex.optimizers.FusedAdam, bnb.optim.Adam),
    # "momentum_apex": (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam),
    "momentum_pytorch": (
        None,
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        bnb.optim.Adam,
    ),
    "adam": (torch.optim.Adam, bnb.optim.Adam),
    "paged_adamw": (torch.optim.AdamW, bnb.optim.PagedAdamW),
    "paged_adam": (torch.optim.Adam, bnb.optim.PagedAdam),
    # "fused_adam": (apex.optimizers.FusedAdam, bnb.optim.Adam),
    "momentum": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD(params, 0.01, 0.9, block_wise=False),
    ),
    "rmsprop": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop(params, 0.01, 0.9, block_wise=False),
    ),
    # 8-bit optimizers with tensor-wise quantization.
    "adam8bit": (
        torch.optim.Adam,
        lambda params: bnb.optim.Adam8bit(params, block_wise=False),
    ),
    "momentum8bit": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD8bit(params, 0.01, 0.9, block_wise=False),
    ),
    "rmsprop8bit": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop8bit(params, 0.01, 0.9, block_wise=False),
    ),
    # 8-bit optimizers with block-wise quantization.
    "adam8bit_blockwise": (
        torch.optim.Adam,
        lambda params: bnb.optim.Adam8bit(params, block_wise=True),
    ),
    "paged_adamw8bit_blockwise": (
        torch.optim.AdamW,
        lambda params: bnb.optim.PagedAdamW8bit(params, block_wise=True),
    ),
    "paged_adam8bit_blockwise": (
        torch.optim.Adam,
        lambda params: bnb.optim.PagedAdam8bit(params, block_wise=True),
    ),
    "momentum8bit_blockwise": (
        lambda params: torch.optim.SGD(params, 0.01, 0.9),
        lambda params: bnb.optim.SGD8bit(params, 0.01, 0.9, block_wise=True),
    ),
    "rmsprop8bit_blockwise": (
        lambda params: torch.optim.RMSprop(params, 0.01, 0.9),
        lambda params: bnb.optim.RMSprop8bit(params, 0.01, 0.9, block_wise=True),
    ),
}

# Maps a test-case name to the optimizer-state entries that are compared.
# 32-bit entries are (torch state key, bnb state key). 8-bit entries append
# the bnb keys of the quantization map and of the scaling statistic:
# "max*" for tensor-wise and "absmax*" for block-wise quantization.
str2statenames = {
    "adam": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "paged_adamw": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "paged_adam": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "momentum": [("momentum_buffer", "state1")],
    "lamb": [("exp_avg", "state1"), ("exp_avg_sq", "state2")],
    "rmsprop": [("square_avg", "state1")],
    "adam8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "lamb8bit": [
        ("exp_avg", "state1", "qmap1", "max1"),
        ("exp_avg_sq", "state2", "qmap2", "max2"),
    ],
    "adam8bit_blockwise": [
        ("exp_avg", "state1", "qmap1", "absmax1"),
        ("exp_avg_sq", "state2", "qmap2", "absmax2"),
    ],
    "paged_adam8bit_blockwise": [
        ("exp_avg", "state1", "qmap1", "absmax1"),
        ("exp_avg_sq", "state2", "qmap2", "absmax2"),
    ],
    "paged_adamw8bit_blockwise": [
        ("exp_avg", "state1", "qmap1", "absmax1"),
        ("exp_avg_sq", "state2", "qmap2", "absmax2"),
    ],
    "momentum8bit": [("momentum_buffer", "state1", "qmap1", "max1")],
    "momentum8bit_blockwise": [("momentum_buffer", "state1", "qmap1", "absmax1")],
    "rmsprop8bit": [("square_avg", "state1", "qmap1", "max1")],
    "rmsprop8bit_blockwise": [("square_avg", "state1", "qmap1", "absmax1")],
}

# Parameter grid for test_optimizer32bit.
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [f"dim1_{d1}_dim2_{d2}_gtype_{gt}_optim_{opt}" for d1, d2, gt, opt in values]

@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    """Compare a 32-bit bnb optimizer step by step with its torch reference.

    Both optimizers start from the same parameters and receive identical
    gradients for ``k`` steps. After every step their states and parameters
    must agree within a dtype-dependent tolerance. Every ``k // 5`` steps the
    bnb optimizer is saved with ``torch.save``, rebuilt, and reloaded to check
    that its state dict round-trips.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not tested")
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    # The torch reference works on a float32 copy; p2 stays in gtype.
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    elif gtype == torch.bfloat16:
        atol, rtol = 1e-3, 1e-2
    else:
        atol, rtol = 1e-4, 1e-3

    # max() guards against k < 5, where k // 5 == 0 would divide by zero.
    checkpoint_every = max(1, k // 5)

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2].cuda(),
                atol=atol,
                rtol=rtol,
            )

        torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)

        if i % checkpoint_every == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                # Drop the old optimizer before rebuilding it from the checkpoint.
                del bnb_optimizer
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # Remove the checkpoint directory even if saving/loading fails.
                rm_path(path)
            torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                )

        if gtype != torch.float32:
            # The optimizer buffers stay close because they are 32-bit, but
            # the 16-bit parameters drift apart, and the difference grows
            # with every update. Re-sync the weights to keep them close.
            p1.data = p1.data.to(p2.dtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_close(p1.to(p2.dtype), p2)

        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0


# Parameter grid for test_global_config.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = [f"dim1_{d1}_dim2_{d2}_gtype_{gt}" for d1, d2, gt in values]


@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
    """Check that a per-parameter GlobalOptimManager override takes effect.

    Three parameters are registered with the global manager while still on
    the CPU, and ``p3`` is overridden to ``optim_bits=8``. All three are then
    moved to CUDA and passed to a single default (32-bit) ``bnb.optim.Adam``.
    After each step ``p3`` must have uint8 state while ``p1`` and ``p2`` keep
    float32 state.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not tested")
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    manager = bnb.optim.GlobalOptimManager.get_instance()
    manager.initialize()
    manager.override_config(p3, "optim_bits", 8)
    manager.register_parameters([p1, p2, p3])

    # Registration happens on the CPU tensors; the override must still apply
    # to the CUDA copies (NOTE(review): presumably matched by position).
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    for _ in range(50):
        p1.grad = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p2.grad = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p3.grad = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001

        adam2.step()

        # Only the overridden parameter is quantized to 8-bit state.
        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8
        assert adam2.state[p1]["state1"].dtype == torch.float32
        assert adam2.state[p2]["state1"].dtype == torch.float32


# Parameter grid for test_optimizer8bit.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [
    "adam8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [f"dim1_{d1}_dim2_{d2}_gtype_{gt}_optim_{opt}" for d1, d2, gt, opt in values]


def _dequantize_state(state, name2, qmap, max_val, blockwise, blocksize):
    """Dequantize one 8-bit bnb optimizer buffer back to a float tensor.

    ``state`` is the bnb per-parameter state dict. ``name2`` is the key of
    the quantized buffer, ``qmap`` the key of its quantization code, and
    ``max_val`` the key of its scaling statistic. The statistic is per-block
    ``absmax`` when ``blockwise`` is true and tensor-wise otherwise.
    ``blocksize`` is used only in the block-wise case.
    """
    if blockwise:
        return F.dequantize_blockwise(
            code=state[qmap],
            absmax=state[max_val],
            A=state[name2],
            blocksize=blocksize,
        )
    return F.dequantize(
        code=state[qmap],
        absmax=state[max_val],
        A=state[name2],
    )


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    """Compare an 8-bit bnb optimizer with its 32-bit torch reference.

    Both optimizers get identical gradients for 100 steps. After each step:

    * the parameters must stay close;
    * the mean absolute and relative parameter errors must stay under fixed
      bounds;
    * the bnb states are dequantized and copied into the torch optimizer,
      and the parameters are re-synced, so errors do not accumulate.

    Every 10 steps the bnb optimizer is round-tripped through
    ``torch.save``/``torch.load``. Its raw buffers, quantization maps and
    dequantized values must then be unchanged, and fewer than 20 elements
    per state may fall outside tolerance of the torch state.
    """
    if gtype == torch.bfloat16 and optim_name not in str2bf16support:
        pytest.skip(f"{optim_name} is not tested with bfloat16")
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not tested")
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048
    blockwise = "blockwise" in optim_name
    statenames = str2statenames[optim_name]

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    # State tolerances are dtype-independent; parameter tolerances are looser
    # for bfloat16.
    atol, rtol = 3e-3, 1e-3
    if gtype == torch.bfloat16:
        patol, prtol = 1e-4, 1e-2
    else:
        patol, prtol = 1e-5, 1e-3

    for i in range(100):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

        # NOTE(review): a per-step "< 20 mismatching elements" assert on these
        # states was commented out; that check only runs after the checkpoint
        # reload below.
        dequant_states = [
            _dequantize_state(
                bnb_optimizer.state[p2], name2, qmap, max_val, blockwise, blocksize
            ).clone()
            for _, name2, qmap, max_val in statenames
        ]

        err = torch.abs(p1 - p2)
        relerr = err / torch.abs(p1)
        if g.dtype == torch.bfloat16:
            assert err.mean() < 0.00015
            assert relerr.mean() < 0.0016
        else:
            assert err.mean() < 0.00012
            assert relerr.mean() < 0.0012

        if i % 10 == 0 and i > 0:
            # Snapshot every state before a single save/reload round-trip.
            snapshots = [
                (
                    bnb_optimizer.state[p2][name2].clone(),
                    bnb_optimizer.state[p2][qmap].clone(),
                    s.clone(),
                )
                for (_, name2, qmap, _), s in zip(statenames, dequant_states)
            ]

            path = get_temp_dir()
            try:
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                # Drop the old optimizer before rebuilding it from the checkpoint.
                del bnb_optimizer
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # Remove the checkpoint directory even if saving/loading fails.
                rm_path(path)

            for (name1, name2, qmap, max_val), (raw, code, dequant) in zip(
                statenames, snapshots
            ):
                reloaded = bnb_optimizer.state[p2]
                torch.testing.assert_close(raw, reloaded[name2])
                torch.testing.assert_close(code, reloaded[qmap])

                s1 = _dequantize_state(
                    reloaded, name2, qmap, max_val, blockwise, blocksize
                )
                torch.testing.assert_close(dequant, s1)

                mismatches = ~torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                )
                assert mismatches.sum().item() < 20
            torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

        # The parameters diverge quickly. Keep them (and the torch states)
        # together so each step is measured against the 8-bit error alone.
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, _, _, _), s in zip(statenames, dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)


# Parameter grid for test_adam_percentile_clipping.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
    f"dim1_{d1}_dim2_{d2}_gtype_{gt}_optim_bits_{bits}"
    for d1, d2, gt, bits in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    """Compare bnb Adam's built-in percentile clipping with manual clipping.

    ``adam1`` receives gradients that are pre-scaled by
    ``F.percentile_clipping`` (using a 100-element ``gnorm_vec`` buffer).
    ``adam2`` is built with ``percentile_clipping=5`` and receives the raw
    gradients. The parameters and states of the two optimizers must stay
    close. For 8-bit state, adam1's states are re-synced from adam2 after
    each comparison. Every 10 steps ``adam2`` is round-tripped through
    ``torch.save``/``torch.load``.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not tested")
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()

    def make_clipped_adam():
        # adam2's configuration; rebuilt identically after every checkpoint.
        return bnb.optim.Adam(
            [p2],
            lr,
            (beta1, beta2),
            eps,
            optim_bits=optim_bits,
            percentile_clipping=5,
        )

    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = make_clipped_adam()

    gnorm_vec = torch.zeros(100).cuda()

    for i in range(50):
        step = i + 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
        g2 = g1.clone()
        p2.grad = g2

        _, _, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), so the states
        # may differ slightly.
        if optim_bits == 32:
            torch.testing.assert_close(p1, p2)
            for key in ("state1", "state2"):
                torch.testing.assert_close(
                    adam1.state[p1][key],
                    adam2.state[p2][key],
                    atol=5e-5,
                    rtol=1e-4,
                )
        elif optim_bits == 8:
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            for key in ("state1", "state2"):
                torch.testing.assert_close(
                    adam1.state[p1][key],
                    adam2.state[p2][key],
                    atol=2,
                    rtol=1e-3,
                )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])

        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(adam2.state_dict(), join(path, "opt.pt"))
                # Drop the old optimizer before rebuilding it from the checkpoint.
                del adam2
                adam2 = make_clipped_adam()
                adam2.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # Previously never removed, so every run leaked a directory.
                rm_path(path)


# Parameter grid for test_benchmark_blockwise.
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# Alternative selections used for ad-hoc benchmarking:
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = [
    "adam8bit_blockwise",
    "paged_adam8bit_blockwise",
    "paged_adamw8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [f"dim1_{d1}_dim2_{d2}_gtype_{gt}_optim_{opt}" for d1, d2, gt, opt in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    """Time ``k`` steps of a bnb optimizer and print seconds per element update.

    The first ``k // 5`` steps are warm-up and are excluded from the timing.
    The same gradient is reused for every step. This is a benchmark, not a
    correctness test: nothing is asserted.
    """
    if dim1 == 1 and dim2 == 1:
        pytest.skip("1x1 parameters are not tested")
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    warmup = k // 5
    for i in range(k):
        if i == warmup:
            # Warm-up done: synchronize so pending warm-up kernels are not
            # timed, then start the clock.
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        bnb_optimizer.step()

    torch.cuda.synchronize()
    s = time.perf_counter() - t0
    print("")
    # Number of parameter-element updates performed inside the timed window.
    params = (k - warmup) * dim1 * dim2
    print(optim_name, gtype, s / params)