"vscode:/vscode.git/clone" did not exist on "675baa79d2434b8bb779859ad08a73b9c0c82fbe"
test_optim.py 20.6 KB
Newer Older
1
import ctypes
import os
import shutil
import time
import uuid
from itertools import product
from os.path import join

import pytest
from lion_pytorch import Lion

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

# import apex

k = 20

def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
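    """Assert that `a` and `b` are elementwise close, allowing up to
    `max_error_count` mismatched values before failing."""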
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
        print(f"Too many values not close: assert {error_count} < {max_error_count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def get_temp_dir():
    path = f"/tmp/autoswap/{str(uuid.uuid4())}"
    os.makedirs(path, exist_ok=True)
    return path


def rm_path(path):
    shutil.rmtree(path)

str2optimizers = {}
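# Each entry maps a test name to optimizer constructors; tests instantiate
# the reference implementation from index [0] and the bitsandbytes
# implementation under test from index [1].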
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)

str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2statenames = {}
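# Each entry maps an optimizer name to its state buffer names, pairing the
# reference optimizer's state key with the bitsandbytes key; 8-bit entries
# also carry the quantization map and (abs)max buffer keys.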
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
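    """Run k update steps with identical gradients and check that the 32-bit
    bitsandbytes optimizer tracks its reference implementation, including
    across a state_dict save/load round trip."""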
    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    elif gtype == torch.bfloat16:
        atol, rtol = 1e-3, 1e-2
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2].cuda(),
                atol=atol,
                rtol=rtol,
            )

        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 10 errors for Lion
        assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
            assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
            for name1, name2 in str2statenames[optim_name]:
                # since Lion can have pretty noisy updates where things lie at the boundary
                # allow up to 10 errors for Lion
                assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
                                         atol=atol, rtol=rtol,
                                         max_error_count=10)

        if gtype != torch.float32:
            # the adam buffers should also be close because they are 32-bit
            # but the parameters can diverge because they are 16-bit
            # the difference grows larger with each update
            # --> copy the state to keep weights close
            p1.data = p1.data.to(p2.dtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_close(p1.to(p2.dtype), p2)
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
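    """Check that GlobalOptimManager per-parameter overrides take effect:
    p3 is overridden to 8-bit optimizer state."""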
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(
        p3, "optim_bits", 8
    )

    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
        [p1, p2, p3]
    )
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [
    "adam8bit",
    "lion8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lion8bit_blockwise",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
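    """Compare 8-bit bitsandbytes optimizers against their 32-bit reference
    implementations, dequantizing the quantized state for comparison and
    verifying a state_dict save/load round trip."""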
    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
    elif gtype == torch.bfloat16:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-4, 1e-2
    else:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

    for i in range(100):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        # since Lion can have pretty noisy updates where things lie at the boundary
        # allow up to 5 errors for Lion
        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            # print(bnb_optimizer.state[p2][max_val], name1)
            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                    blocksize=blocksize,
                )
            else:
                s1 = F.dequantize(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
            num_not_close = (
                torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                )
                == 0
            )
            #assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / (torch.abs(p1) + 1e-9)
        if g.dtype == torch.bfloat16:
            assert err.mean() < 0.00015
            assert relerr.mean() < 0.0016
        else:
            assert err.mean() < 0.00012
            assert relerr.mean() < 0.0012

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                rm_path(path)
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

                if "blockwise" in optim_name:
                    s1 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )
                else:
                    s1 = F.dequantize(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                    )
                torch.testing.assert_close(s1cpy, s1)

                num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
                assert num_not_close.sum().item() < 20
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 5 errors for Lion
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)

    # print(sum(errors)/len(errors))
    # print(sum(relerrors)/len(relerrors))


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
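    """Compare Adam with percentile_clipping=5 against Adam fed gradients
    that were clipped manually via F.percentile_clipping, for both 32-bit
    and 8-bit optimizer states."""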
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
    )

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
            0.01 * i
        )
        g2 = g1.clone()
        p2.grad = g2

        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
            g1, gnorm_vec, step, 5
        )
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_close(p1, p2)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
            )
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
                rtol=1e-3,
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            torch.save(adam2.state_dict(), join(path, "opt.pt"))
            del adam2
            adam2 = None
            adam2 = bnb.optim.Adam(
                [p2],
                lr,
                (beta1, beta2),
                eps,
                optim_bits=optim_bits,
                percentile_clipping=5,
            )
            adam2.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)


dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
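    """Rough throughput benchmark: time the optimizer steps that follow a
    burn-in phase and report seconds per parameter update."""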
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    for i in range(k):
        if i == k // 5:
            # time only the steps after the first k // 5 burn-in iterations
            torch.cuda.synchronize()
            t0 = time.time()

        bnb_optimizer.step()

    torch.cuda.synchronize()
    s = time.time() - t0
    print("")
    params = (k - k // 5) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9


dim1 = [2 * 1024]
gtype = [torch.float16]
#mode = ['torch', 'bnb']
mode = ['bnb']
optimizer_names = ['paged_adamw']
#optimizer_names = ['paged_adamw8bit_blockwise']
values = list(product(dim1, gtype, optimizer_names, mode))
names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
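    """Benchmark a paged optimizer while a large dummy tensor occupies most
    GPU memory, so that paged optimizer state has to move between CPU and
    GPU during step()."""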
    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
    layers1 = layers1.to(gtype)
    layers1 = layers1.cuda()

    large_tensor = None
    if mode == 'torch':
        optim = str2optimizers[optim_name][0](layers1.parameters())
    else:
        optim = str2optimizers[optim_name][1](layers1.parameters())
        # fill GPU memory with a large dummy tensor (~18 GB at float32)
        # so paged optimizer state cannot stay resident on the GPU
        large_tensor = torch.empty((int(4.5e9),), device='cuda')

    torch.cuda.synchronize()
    time.sleep(5)

    num_batches = 5
    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
    lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()

    for i in range(num_batches):
        print(i)
        b = batches[i]
        if i == 2:
            torch.cuda.synchronize()
            t0 = time.time()

        out1 = layers1(b)

        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
        loss1.backward()
        optim.step()
    torch.cuda.synchronize()
    print(mode, time.time() - t0)