Commit abf4a1e3 authored by Egor Krivov

enabled tests

parent 35ce337b
@@ -11,7 +11,8 @@ import torch
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
-from tests.helpers import describe_dtype, id_formatter
+from bitsandbytes.utils import sync_gpu
+from tests.helpers import describe_dtype, get_available_devices, id_formatter
 # import apex
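
The two new imports carry the portability work: get_available_devices() drives the device parametrization used below, and sync_gpu() replaces the hard-coded torch.cuda.synchronize() calls in the benchmark. A minimal sketch of what such helpers could look like, purely as an assumption; the real implementations in tests/helpers.py and bitsandbytes/utils.py may differ:

```python
# Hypothetical sketch only -- the actual helpers may enumerate more
# backends or use a different detection mechanism.
import torch


def get_available_devices():
    # List every backend this torch build can actually run on,
    # so tests execute on CPU-only machines instead of being skipped.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    return devices


def sync_gpu(t):
    # Wait for all queued kernels on the tensor's device to finish;
    # effectively a no-op for CPU tensors.
    if t.device.type == "cuda":
        torch.cuda.synchronize()
    elif t.device.type == "xpu":
        torch.xpu.synchronize()
```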
@@ -168,7 +169,8 @@ optimizer_names_32bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
-def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
+@pytest.mark.parametrize("device", get_available_devices(), ids=id_formatter("device"))
+def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
     if optim_name.startswith("paged_") and sys.platform == "win32":
         pytest.skip("Paged optimizers can have issues on Windows.")
@@ -176,7 +178,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
         pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
-    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
     p2 = p1.clone()
     p1 = p1.float()
@@ -191,7 +193,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
     atol, rtol = 1e-4, 1e-3
     for i in range(k):
-        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
         p1.grad = g.clone().float()
         p2.grad = g.clone()
@@ -201,7 +203,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
         for name1, name2 in str2statenames[optim_name]:
             torch.testing.assert_close(
                 torch_optimizer.state[p1][name1],
-                bnb_optimizer.state[p2][name2].cuda(),
+                bnb_optimizer.state[p2][name2].to(device),
                 atol=atol,
                 rtol=rtol,
             )
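
Moving the bitsandbytes optimizer state with .to(device) instead of .cuda() keeps the comparison valid on any backend: .cuda() would raise on a CUDA-less machine, while .to(device) is a no-op when the tensor already lives there. A tiny illustration with made-up tensors:

```python
import torch

device = "cpu"  # any entry returned by get_available_devices(), e.g. "cuda"

state = torch.randn(4, 4)             # stands in for an optimizer state tensor
reference = state.to(device).clone()  # reference values on the test device

# .to(device) copies only when necessary and works for every backend,
# so the assertion behaves the same on CPU, CUDA, or XPU runs.
torch.testing.assert_close(reference, state.to(device))
```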
@@ -247,7 +249,8 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
-def test_global_config(requires_cuda, dim1, dim2, gtype):
+@pytest.mark.parametrize("device", get_available_devices())
+def test_global_config(dim1, dim2, gtype, device):
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
@@ -263,9 +266,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
     bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
     bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
-    p1 = p1.cuda()
-    p2 = p2.cuda()
-    p3 = p3.cuda()
+    p1 = p1.to(device)
+    p2 = p2.to(device)
+    p3 = p3.to(device)
     adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)
@@ -275,9 +278,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
     atol, rtol = 1e-4, 1e-3
     for i in range(50):
-        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
-        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
-        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
+        g1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
+        g2 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
+        g3 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
         p1.grad = g1
         p2.grad = g2
         p3.grad = g3
@@ -302,13 +305,14 @@ optimizer_names_8bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
-def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
+@pytest.mark.parametrize("device", get_available_devices())
+def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
     torch.set_printoptions(precision=6)
     if dim1 == 1 and dim2 == 1:
         return
-    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
     p2 = p1.clone()
     p1 = p1.float()
     blocksize = 256
@@ -330,12 +334,12 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
     relerrors = []
     for i in range(50):
-        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
         p1.grad = g.clone().float()
         p2.grad = g.clone()
-        bnb_optimizer.step()
         torch_optimizer.step()
+        bnb_optimizer.step()
         # since Lion can have pretty noisy updates where things lie at the boundary
         assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
@@ -368,7 +372,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
             )
             num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
-            # assert num_not_close.sum().item() < 20
+            assert num_not_close.sum().item() < 20
             dequant_states.append(s1.clone())
         err = torch.abs(p1 - p2)
@@ -549,25 +553,25 @@ optimizer_names_benchmark = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
 @pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
 @pytest.mark.benchmark
-def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
+def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device):
     if dim1 == 1 and dim2 == 1:
         return
-    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
     bnb_optimizer = str2optimizers[optim_name][1]([p1])
-    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+    g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
     p1.grad = g
     total_steps = 500
    for i in range(total_steps):
         if i == total_steps // 5:
             # 100 iterations for burn-in
-            torch.cuda.synchronize()
+            sync_gpu(p1)
             t0 = time.time()
         bnb_optimizer.step()
-    torch.cuda.synchronize()
+    sync_gpu(p1)
     s = time.time() - t0
     print("")
     params = (total_steps - total_steps // 5) * dim1 * dim2
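
The benchmark keeps its warm-up/measure structure and only swaps the CUDA-specific synchronize calls for the device-agnostic helper. A condensed sketch of that timing pattern, assuming the sync_gpu helper sketched earlier; the function name here is illustrative:

```python
import time


def time_steady_state_steps(optimizer, p, total_steps=500):
    # The first fifth of the iterations is treated as burn-in and not timed.
    burn_in = total_steps // 5
    t0 = None
    for i in range(total_steps):
        if i == burn_in:
            sync_gpu(p)        # drain pending kernels before starting the clock
            t0 = time.time()
        optimizer.step()
    sync_gpu(p)                # ensure all timed kernels have finished
    return time.time() - t0
```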