from itertools import product
import os
from os.path import join
import shutil
import time
import uuid

from lion_pytorch import Lion
import pytest
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

# import apex

k = 20  # number of optimizer update steps per test loop

def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
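    """Assert that ``a`` and ``b`` are elementwise close, tolerating at most
    ``max_error_count`` mismatched elements before failing."""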
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
        print(f"Too many values not close: assert {error_count} < {max_error_count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def get_temp_dir():
    path = f"/tmp/autoswap/{uuid.uuid4()}"
    os.makedirs(path, exist_ok=True)
    return path


def rm_path(path):
    shutil.rmtree(path)

str2optimizers = {}
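# Each entry maps a test name to optimizer constructors; the tests below index
# [0] for the reference implementation and [1] for the bitsandbytes optimizer.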
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)

str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2statenames = {}
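# Maps a test name to (reference state name, bnb state name) pairs; the 8-bit
# entries additionally name the quantization map and (abs)max buffers.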
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
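    """Run k update steps with a reference optimizer and the corresponding
    bitsandbytes optimizer on identical parameters, checking that parameters
    and optimizer state stay close; periodically round-trip the bitsandbytes
    state_dict through save/load."""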
    if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
        pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    elif gtype == torch.bfloat16:
        atol, rtol = 1e-3, 1e-2
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()


        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2].cuda(),
                atol=atol,
                rtol=rtol,
            )

        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 10 errors for Lion
        assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
            assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
            for name1, name2 in str2statenames[optim_name]:
                # since Lion can have pretty noisy updates where things lie at the boundary
                # allow up to 10 errors for Lion
                assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
                                         atol=atol, rtol=rtol,
                                         max_error_count=10)

        if gtype != torch.float32:
            # the adam buffers should also be close because they are 32-bit
            # but the parameters can diverge because they are 16-bit
            # the differences grow larger and larger with each update
            # --> copy the state to keep weights close
            p1.data = p1.data.to(p2.dtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_close(p1.to(p2.dtype), p2)
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
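    """Override optim_bits to 8 for one parameter via GlobalOptimManager and
    verify that its optimizer state is stored as uint8."""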
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(
        p3, "optim_bits", 8
    )

    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
        [p1, p2, p3]
    )
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [
    "adam8bit",
    "lion8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lion8bit_blockwise",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
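    """Compare 8-bit bitsandbytes optimizers against 32-bit reference
    implementations: dequantize the 8-bit state for comparison, bound the mean
    parameter error, and check that save/load restores the quantized state
    exactly."""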
    if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]:
        pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
    elif gtype == torch.bfloat16:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-4, 1e-2
    else:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

    for i in range(100):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 5 errors for Lion
        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            # print(bnb_optimizer.state[p2][max_val], name1)
            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                    blocksize=blocksize,
                )
            else:
                s1 = F.dequantize(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
            num_not_close = (
                torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                )
                == 0
            )
            #assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / (torch.abs(p1) + 1e-9)
        if g.dtype == torch.bfloat16:
            assert err.mean() < 0.00015
            assert relerr.mean() < 0.0016
        else:
            assert err.mean() < 0.00012
            assert relerr.mean() < 0.0012

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                rm_path(path)
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

                if "blockwise" in optim_name:
                    s1 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )
                else:
                    s1 = F.dequantize(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                    )
                torch.testing.assert_close(s1cpy, s1)

                num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
                assert num_not_close.sum().item() < 20
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 5 errors for Lion
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)

    # print(sum(errors)/len(errors))
    # print(sum(relerrors)/len(relerrors))


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
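    """Check that manually scaling gradients with F.percentile_clipping matches
    bnb.optim.Adam's built-in percentile_clipping, in 32-bit and 8-bit modes."""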
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
    )

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
            0.01 * i
        )
        g2 = g1.clone()
        p2.grad = g2

        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
            g1, gnorm_vec, step, 5
        )
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_close(p1, p2)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
                rtol=1e-3,
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            torch.save(adam2.state_dict(), join(path, "opt.pt"))
            del adam2
            adam2 = None
            adam2 = bnb.optim.Adam(
                [p2],
                lr,
                (beta1, beta2),
                eps,
                optim_bits=optim_bits,
                percentile_clipping=5,
            )
            adam2.load_state_dict(torch.load(join(path, "opt.pt")))


dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
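    """Benchmark the per-parameter step time of (paged) 8-bit blockwise
    optimizers after a k // 5 step burn-in."""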
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    for i in range(k):
        if i == k // 5:
            # first k // 5 iterations are burn-in; start timing after them
            torch.cuda.synchronize()
            t0 = time.time()

        bnb_optimizer.step()

    torch.cuda.synchronize()
    s = time.time() - t0
    print("")
    params = (k - k // 5) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9

dim1 = [2 * 1024]
gtype = [torch.float16]
# mode = ["torch", "bnb"]
mode = ["bnb"]
optimizer_names = ["paged_adamw"]
# optimizer_names = ["paged_adamw8bit_blockwise"]
values = list(product(dim1, gtype, optimizer_names, mode))
names = ["dim1_{}_gtype_{}_optim_{}_mode_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
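    """Benchmark optimizer steps for a small MLP; in 'bnb' mode a large
    allocation holds most GPU memory to exercise paged optimizer state."""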
    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
    layers1 = layers1.to(gtype)
    layers1 = layers1.cuda()

    large_tensor = None
    if mode == 'torch':
        optim = str2optimizers[optim_name][0](layers1.parameters())
    else:
        optim = str2optimizers[optim_name][1](layers1.parameters())
        # allocate a large tensor (~18 GB of float32) to create GPU memory pressure
        large_tensor = torch.empty((int(4.5e9),), device='cuda')

    torch.cuda.synchronize()
    time.sleep(5)

    num_batches = 5
    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()

    for i in range(num_batches):
        print(i)
        b = batches[i]
        if i == 2:
            torch.cuda.synchronize()
            t0 = time.time()

        out1 = layers1(b)

        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
        loss1.backward()
        optim.step()
    torch.cuda.synchronize()
    print(mode, time.time() - t0)