"vscode:/vscode.git/clone" did not exist on "3b6a9154dde490da393630e7790136e7f516d3c1"
Commit 74399248 authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Initial commit

parents
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import bitsandbytes as bnb
from itertools import product
from bitsandbytes import functional as F
def setup():
pass
def teardown():
pass
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half'])
def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device='cuda')
A = A.to(dtype)
code = F.estimate_quantiles(A)
percs = torch.linspace(1/512, 511/512, 256, device=A.device)
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device='cuda')
A = A.to(dtype)
code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs)
diff = torch.abs(code-quantiles)
assert (diff > 5e-02).sum().item() == 0
def test_quantile_quantization():
for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda')
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1-A2).mean().item()
assert diff < 0.0075
A1 = torch.rand(1024, 1024, device='cuda')
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1-A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda')
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1-A2)
reldiff = diff/torch.abs(A1+1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
print(sum(diffs)/len(diffs))
print(sum(reldiffs)/len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device='cuda')
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1-A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
def test_dynamic_blockwise_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda')
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1-A2)
reldiff = diff/torch.abs(A1+1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
print(sum(diffs)/len(diffs))
print(sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
A1 = torch.rand(1024, 1024, device='cuda')
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1-A2).mean().item()
assert diff < 0.0033
diffs.append(diff)
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
#print(sum(diffs)/len(diffs))
def test_dynamic_blockwise_stochastic_quantization():
diffs = []
reldiffs = []
rand = torch.rand(1024).cuda()
for i in range(100):
A1 = torch.randn(1024, 1024, device='cuda')
C1, S1 = F.quantize_blockwise(A1, rand=rand)
C2, S2 = F.quantize_blockwise(A1)
# a maximunm distance of quantized values of 1
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
fraction_smaller = (C1<C2).float().sum()/C1.numel()
fraction_larger = (C1>C2).float().sum()/C1.numel()
torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half'])
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device='cuda')
gnorm_vec2 = torch.zeros(100, device='cuda')
n = 4
step = 0
percentile=5
for i in range(1000):
step += 1
g = torch.randn(n, n, dtype=gtype, device='cuda')
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1
gnorm2 = torch.norm(g.float())
if step == 1:
gnorm_vec1[:] = gnorm2
else:
gnorm_vec1[step % 100] = gnorm2
vals, idx = torch.sort(gnorm_vec1)
clip1 = vals[percentile]
torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
torch.testing.assert_allclose(clip1, clip2)
torch.testing.assert_allclose(gnorm1, gnorm2)
def test_stable_embedding():
layer = bnb.nn.StableEmbedding(1024, 1024)
layer.reset_parameters()
def test_dynamic_blockwise_quantization_cpu():
#A1 = torch.randn(1024, 1024, device='cpu')
#code = F.create_dynamic_map()
#for i in range(1000):
# C, S = F.quantize_blockwise(A1, code=code)
# A2 = F.dequantize_blockwise(C, S)
for i in range(10):
# equivalence with GPU blockwise quantization
A1 = torch.randn(1024, 1024, device='cpu')
C1, S1 = F.quantize_blockwise(A1)
C2, S2 = F.quantize_blockwise(A1.cuda())
torch.testing.assert_allclose(S1[0], S2[0].cpu())
# there seems to be some issues with precision in CUDA vs CPU
# not all elements are usually close, with couple off elements in a million
idx = torch.isclose(C1, C2.cpu())
assert (idx==0).sum().item() < 15
diffs = []
reldiffs = []
for i in range(10):
A1 = torch.randn(1024, 1024, device='cpu')
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1-A2)
reldiff = diff/torch.abs(A1+1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
#print(sum(diffs)/len(diffs))
#print(sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(10):
A1 = torch.rand(1024, 1024, device='cpu')
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1-A2).mean().item()
assert diff < 0.0033
diffs.append(diff)
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
#print(sum(diffs)/len(diffs))
def test_histogram():
dim1, dim2 = 32, 32
source = torch.rand(dim1, dim2, device='cuda')
idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
histogram1 = torch.zeros((256, 256)).cuda()
histogram2 = torch.zeros((256, 256)).cuda()
F.histogram_scatter_add_2d(histogram2, idx1, idx2, source)
for i in range(dim1):
for j in range(dim2):
histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j]
torch.testing.assert_allclose(histogram1, histogram2)
torch.testing.assert_allclose(histogram1.sum(), source.sum())
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from os.path import join
from itertools import product
import apex
def get_temp_dir():
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
os.makedirs(path, exist_ok=True)
return path
def rm_path(path):
shutil.rmtree(path)
str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
p2 = p1.clone()
p1 = p1.float()
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
for i in range(50):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
bnb_optimizer.step()
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
if i % 10 == 0 and i > 0:
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit
# but the paramters can diverge because they are 16-bit
# the difference grow larger and larger with each update
# --> copy the state to keep weights close
p1.data = p1.data.half().float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2)
if optim_name in ['lars', 'lamb']:
assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1,dim2, gtype))
names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
mask = torch.rand_like(p2) < 0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
p1 = p1.cuda()
p2 = p2.cuda()
p3 = p3.cuda()
adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)
if gtype == torch.float32:
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
p1.grad = g1
p2.grad = g2
p3.grad = g3
adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
atol, rtol = 3e-3, 1e-3
patol, prtol = 1e-5, 1e-3
else:
atol, rtol = 3e-3, 1e-3
patol, prtol = 1e-5, 1e-3
errors = []
relerrors = []
for i in range(50):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
bnb_optimizer.step()
torch_optimizer.step()
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
#print(bnb_optimizer.state[p2][max_val], name1)
if 'blockwise' in optim_name:
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
else:
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
err = torch.abs(p1-p2)
relerr = err/torch.abs(p1)
assert err.mean() < 0.0001
assert relerr.mean() < 0.001
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
rm_path(path)
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
if 'blockwise' in optim_name:
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
else:
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
torch.testing.assert_allclose(s1cpy, s1)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
torch_optimizer.state[p1][name1].copy_(s.data)
#print(sum(errors)/len(errors))
#print(sum(relerrors)/len(relerrors))
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1,dim2, gtype, optim_bits))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
eps = 1e-8
p1 = p1.cuda()
p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
gnorm_vec = torch.zeros(100).cuda()
step = 0
for i in range(50):
step += 1
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i)
g2 = g1.clone()
p2.grad = g2
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
g1 = (g1.float()*gnorm_scale).to(gtype)
p1.grad = g1
adam1.step()
adam2.step()
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32:
torch.testing.assert_allclose(p1, p2)
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4)
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4)
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3)
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3)
adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1'])
adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2'])
if i % 10 == 0 and i > 0:
path = get_temp_dir()
torch.save(adam2.state_dict(),join(path, 'opt.pt'))
del adam2
adam2 = None
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
adam2.load_state_dict(torch.load(join(path, 'opt.pt')))
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
#optimizer_names = ['lamb_apex', 'lamb8bit']
#optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ['adam8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g
for i in range(5000):
if i == 500:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
bnb_optimizer.step()
torch.cuda.synchronize()
s = time.time()-t0
print('')
params = 4500*4096*4096
print(optim_name, gtype, s/params)
#assert s < 3.9
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment