Commit a24aae30 authored by Jeongseok Kang

Merge branch 'main' into fix/libcuda-to-torch

parents 2b4cc256 4395d68c
@@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
fi
-make clean
-export CUDA_HOME=$BASE_PATH/cuda-10.2
-make cuda10x_nomatmul CUDA_VERSION=102
-if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then
-    # Control will enter here if $DIRECTORY doesn't exist.
-    echo "Compilation unsuccessul!" 1>&2
-    exit 64
-fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.0
make cuda110_nomatmul CUDA_VERSION=110
...
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_NEW_TOKENS = 128
model_name = 'decapoda-research/llama-7b-hf'

text = 'Hamburg is in which country?\n'
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids

# Leave roughly 2 GB of headroom per GPU when sharding the model.
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f'{free_in_GB-2}GB'

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

# Load the checkpoint in 8-bit via bitsandbytes and shard it across the available GPUs.
model = AutoModelForCausalLM.from_pretrained(
  model_name,
  device_map='auto',
  load_in_8bit=True,
  max_memory=max_memory
)

generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
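
Most of this merge adds tests for the new 4-bit (FP4/NF4) quantization path. For orientation, a minimal sketch of the functional API those tests exercise, assuming bitsandbytes >= 0.39 on a CUDA device; the shapes and the compress_statistics/quant_type choices below are illustrative and not taken from the commit:

# Illustrative sketch of the 4-bit API exercised by the tests in this merge (not part of the commit itself).
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
B = torch.randn(4 * 4096, 4096, dtype=torch.float16, device='cuda')

# Pack the weight into 4-bit NF4 blocks ('fp4' is the other supported quant_type).
qB, quant_state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=True)

# Either dequantize back to fp16 or multiply directly against the packed weight.
B_dq = F.dequantize_4bit(qB, quant_state, quant_type='nf4')
out = bnb.matmul_4bit(A, qB.t(), quant_state)
print(out.shape, (out - A @ B.t()).abs().mean())
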
@@ -18,10 +18,10 @@ def read(fname):
setup(
    name=f"bitsandbytes",
-    version=f"0.38.0",
+    version=f"0.39.1",
    author="Tim Dettmers",
    author_email="dettmers@cs.washington.edu",
-    description="8-bit optimizers and matrix multiplication routines.",
+    description="k-bit optimizers and matrix multiplication routines.",
    license="MIT",
    keywords="gpu optimizers optimization 8-bit quantization compression",
    url="https://github.com/TimDettmers/bitsandbytes",
...
@@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
-torch.testing.assert_allclose(
+torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
@@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
-torch.testing.assert_allclose(
+torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3
)
@@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
-torch.testing.assert_allclose(
+torch.testing.assert_close(
out_bnb, out_torch, atol=0.027, rtol=0.2
)
@@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
-torch.testing.assert_allclose(
+torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
@@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
-torch.testing.assert_allclose(
+torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
@@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
decomp = [0.0, 6.0]
-funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ["matmul"]
+funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
+str_funcs = ["matmullt", 'switchback_bnb']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
@@ -407,7 +407,7 @@ def test_matmullt(
bias.grad = None
if req_grad[0]:
-torch.testing.assert_allclose(
+torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1
)
if req_grad[1]:
@@ -423,9 +423,204 @@ def test_matmullt(
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
-torch.testing.assert_allclose(
+torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3
)
if req_grad[2]:
-torch.testing.assert_allclose(gradBias1, gradBias2)
+torch.testing.assert_close(gradBias1, gradBias2)
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
funcs = [(torch.matmul, bnb.matmul_4bit)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
strval = ''
for v in c:
if v == True: strval += 'T'
else: strval += 'F'
req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
compress_statistics = [False, True]
has_fp16_weights = [True, False]
has_bias = [True, False]
quant_type = ['fp4', 'nf4']
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
if has_bias == False:
req_grad = list(req_grad)
req_grad[2] = False
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
elif not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
if has_bias:
out_torch += bias
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
n = out_bnb.numel()
err = torch.abs(out_bnb - out_torch).float().mean().item()
if n > 0:
assert err < 0.115
#assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
if has_bias:
gradBias1 = bias.grad
bias.grad = None
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if has_bias:
gradBias2 = bias.grad
bias.grad = None
if req_grad[0]:
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[2]:
torch.testing.assert_close(gradBias1, gradBias2)
funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
strval = ''
for v in c:
if v == True: strval += 'T'
else: strval += 'F'
req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
req_grad = list(req_grad)
req_grad[2] = False
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
torch.nn.init.xavier_uniform_(B)
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
elif not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B, fw_code, bw_code)
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
n = out_bnb.numel()
err = torch.abs(out_bnb - out_torch).float().mean().item()
if n > 0:
assert err < 0.115
#assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
grad_err = (gradB1-gradB2).abs().mean()
assert grad_err.item() < 0.003
torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3
)
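
The test_matmul_fp8 additions above drive the experimental FP8 matmul through lookup tables built with F.create_fp8_map. A minimal usage sketch under the same assumptions (CUDA device, the bnb.research namespace from this release; shapes illustrative):

# Illustrative FP8 matmul sketch mirroring the calls in test_matmul_fp8 above (not part of the commit).
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 1024, dtype=torch.float16, device='cuda')
B = torch.randn(4 * 1024, 1024, dtype=torch.float16, device='cuda')

# Forward map with 4 exponent / 3 mantissa bits, backward map with 5 / 2, 8 bits total,
# matching the fw_code/bw_code used in the tests.
fw_code = F.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = F.create_fp8_map(True, 5, 2, 8).to(A.device)

out = bnb.research.matmul_fp8_mixed(A, B.t(), fw_code, bw_code)
print(out.shape, out.dtype)
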
@@ -18,12 +18,15 @@ torch.set_printoptions(
k = 20
-def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
+def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
idx = torch.isclose(a, b, rtol, atol)
sumval = (idx == 0).sum().item()
if sumval > count:
-print(f"Too many values not close: assert {sumval} < {count}")
-torch.testing.assert_allclose(a, b, rtol, atol)
+if throw:
+print(f"Too many values not close: assert {sumval} < {count}")
+torch.testing.assert_close(a, b, rtol, atol)
+return sumval
class FFN(torch.nn.Module):
@@ -97,7 +100,7 @@ def test_estimate_quantiles(dtype):
code = F.estimate_quantiles(A)
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
-torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
+torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype)
@@ -122,7 +125,7 @@ def test_quantile_quantization():
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
-torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
+torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
@@ -146,63 +149,49 @@ def test_dynamic_quantization():
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2).mean().item()
-torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
-def test_dynamic_blockwise_quantization():
+@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
+@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
+def test_dynamic_blockwise_quantization(nested, blocksize):
#print('')
-for blocksize in [4096, 2048, 1024, 512]:
-diffs = []
-reldiffs = []
-for i in range(100):
-A1 = torch.randn(1024, 1024, device="cuda")
-C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-diff = torch.abs(A1 - A2)
-reldiff = diff / torch.abs(A1 + 1e-8)
-diffs.append(diff.mean().item())
-reldiffs.append(reldiff.mean().item())
-abserr = sum(diffs)/len(diffs)
-relerr = sum(reldiffs)/len(reldiffs)
-assert abserr < 0.011
-assert relerr < 0.018
-#print('randn', blocksize, sum(diffs)/len(diffs))
-#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
-diffs = []
-for i in range(100):
-A1 = torch.rand(1024, 1024, device="cuda")
-C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-diff = torch.abs(A1 - A2)
-reldiff = diff / torch.abs(A1 + 1e-8)
-diffs.append(diff.mean().item())
-reldiffs.append(reldiff.mean().item())
-#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
-abserr = sum(diffs)/len(diffs)
-relerr = sum(reldiffs)/len(reldiffs)
-assert abserr < 0.0035
-assert relerr < 0.015
-#print('rand', blocksize, sum(diffs)/len(diffs))
-#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
-def test_dynamic_blockwise_stochastic_quantization():
-diffs = []
-reldiffs = []
-rand = torch.rand(1024).cuda()
-for i in range(100):
-A1 = torch.randn(1024, 1024, device="cuda")
-C1, S1 = F.quantize_blockwise(A1, rand=rand)
-C2, S2 = F.quantize_blockwise(A1)
-# a maximunm distance of quantized values of 1
-torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
-fraction_smaller = (C1 < C2).float().sum() / C1.numel()
-fraction_larger = (C1 > C2).float().sum() / C1.numel()
-torch.testing.assert_allclose(
-fraction_larger, fraction_smaller, atol=0.01, rtol=0
-)
+diffs = []
+reldiffs = []
+for i in range(100):
+A1 = torch.randn(1024, 1024, device="cuda")
+C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+A2 = F.dequantize_blockwise(C, S)
+diff = torch.abs(A1 - A2)
+reldiff = diff / torch.abs(A1 + 1e-8)
+diffs.append(diff.mean().item())
+reldiffs.append(reldiff.mean().item())
+abserr = sum(diffs)/len(diffs)
+relerr = sum(reldiffs)/len(reldiffs)
+assert abserr < 0.011
+assert relerr < 0.018
+#print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
+#print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
+diffs = []
+for i in range(100):
+A1 = torch.rand(1024, 1024, device="cuda")
+C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+A2 = F.dequantize_blockwise(C, S)
+diff = torch.abs(A1 - A2)
+reldiff = diff / torch.abs(A1 + 1e-8)
+diffs.append(diff.mean().item())
+reldiffs.append(reldiff.mean().item())
+#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
+abserr = sum(diffs)/len(diffs)
+relerr = sum(reldiffs)/len(reldiffs)
+assert abserr < 0.0035
+assert relerr < 0.015
+#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
@pytest.mark.parametrize(
@@ -231,9 +220,9 @@ def test_percentile_clipping(gtype):
vals, idx = torch.sort(gnorm_vec1)
clip1 = vals[percentile]
-torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
-torch.testing.assert_allclose(clip1, clip2)
-torch.testing.assert_allclose(gnorm1, gnorm2)
+torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
+torch.testing.assert_close(clip1, clip2)
+torch.testing.assert_close(gnorm1, gnorm2)
def quant(x):
@@ -315,7 +304,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
-print("")
+#print("")
for i in range(5):
if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
@@ -327,7 +316,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
-torch.testing.assert_allclose(
+torch.testing.assert_close(
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
)
if batched:
@@ -344,8 +333,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
-print(mean(errors))
-print(mean(relerrors))
+#print(mean(errors))
+#print(mean(relerrors))
def test_stable_embedding():
@@ -398,7 +387,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
out2 = torch.matmul(A.t().float(), B.t().float())
out = F.igemm(A.t(), B.t())
-torch.testing.assert_allclose(out.float(), out2)
+torch.testing.assert_close(out.float(), out2)
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
@@ -416,7 +405,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t())
-torch.testing.assert_allclose(out.float(), out2)
+torch.testing.assert_close(out.float(), out2)
n = 3
@@ -447,7 +436,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
)
out = F.igemm(A, B, out=iout)
-torch.testing.assert_allclose(out.float(), out2)
+torch.testing.assert_close(out.float(), out2)
n = 2
@@ -572,7 +561,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
)
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
-torch.testing.assert_allclose(out.float(), out2.float())
+torch.testing.assert_close(out.float(), out2.float())
n = 1
@@ -630,9 +619,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
out, S = F.nvidia_transform(A, to_order=orderOut)
if orderOut == "row":
-torch.testing.assert_allclose(A.flatten(), out.flatten())
+torch.testing.assert_close(A.flatten(), out.flatten())
elif orderOut == "col":
-torch.testing.assert_allclose(A.t().flatten(), out.flatten())
+torch.testing.assert_close(A.t().flatten(), out.flatten())
elif orderOut == "col32":
if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
@@ -665,14 +654,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
assert A.flatten()[i + j] == A[row, col]
# assert A.flatten()[i+j] == out.flatten()[row2+col2]
-# torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
-# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
+# torch.testing.assert_close(A.flatten()[i+j], A[row, col])
+# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32":
out2, S = F.nvidia_transform(
out, from_order=orderOut, to_order="row", state=S
)
-torch.testing.assert_allclose(A, out2)
+torch.testing.assert_close(A, out2)
n = 1
@@ -716,7 +705,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
B2, SB = F.transform(B, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC)
-torch.testing.assert_allclose(C1, C3.float())
+torch.testing.assert_close(C1, C3.float())
# transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
@@ -727,7 +716,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
C3, S = F.nvidia_transform(C2, "row", state=SC)
-torch.testing.assert_allclose(C1, C3.float())
+torch.testing.assert_close(C1, C3.float())
dim1 = [32]
@@ -773,7 +762,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# print(C1.flatten()[:10])
# print(C2.flatten()[:10])
-# torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
+# torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# transpose
# B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
@@ -782,7 +771,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
# C2, SC = F.igemmlt(A2, B2t, SA, SBt)
# C3, S = F.transform(C2, 'row', state=SC)
-# torch.testing.assert_allclose(C1, C3.float())
+# torch.testing.assert_close(C1, C3.float())
batch_size = 2
@@ -1001,7 +990,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
#assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
-#torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1)
+#torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))
@@ -1051,16 +1040,16 @@ def test_colrow_absmax(dim1, dim2, dims):
)
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
-torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
-torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
-torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
+torch.testing.assert_close(col_stats1_trunc, col_stats2)
+torch.testing.assert_close(row_stats1_trunc, row_stats2)
+torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=0.0
)
-torch.testing.assert_allclose(col_stats1, col_stats2)
-torch.testing.assert_allclose(row_stats1, row_stats2)
+torch.testing.assert_close(col_stats1, col_stats2)
+torch.testing.assert_close(row_stats1, row_stats2)
assert nnz_block_ptr2 is None
@@ -1084,8 +1073,8 @@ def test_double_quant(dim1, dim2):
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# max difference is 1 due to rounding differences
-torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
-torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
+torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
+torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (
@@ -1108,8 +1097,8 @@ def test_double_quant(dim1, dim2):
)
assert False
-torch.testing.assert_allclose(Srow.flatten(), statsA)
-torch.testing.assert_allclose(Scol.flatten(), statsAt)
+torch.testing.assert_close(Srow.flatten().float(), statsA)
+torch.testing.assert_close(Scol.flatten().float(), statsAt)
n = 4
@@ -1134,10 +1123,10 @@ def test_integrated_igemmlt(dim1, dim4, inner):
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
-torch.testing.assert_allclose(maxA.flatten(), stats1a)
-torch.testing.assert_allclose(maxB.flatten(), stats2a)
-torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
-torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
+torch.testing.assert_close(maxA.flatten().float(), stats1a)
+torch.testing.assert_close(maxB.flatten().float(), stats2a)
+torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
+torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(C2a, "col_turing")
@@ -1339,7 +1328,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
# print(out1)
# print(out2)
-torch.testing.assert_allclose(out1, out2)
+torch.testing.assert_close(out1, out2)
n = 2
@@ -1401,11 +1390,11 @@ def test_coo_double_quant(dim1, dim2):
A2[
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
] = coo_tensor.values
-torch.testing.assert_allclose(A1, A2)
+torch.testing.assert_close(A1, A2)
A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
-torch.testing.assert_allclose(
+torch.testing.assert_close(
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
)
@@ -1613,7 +1602,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
-# torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
+# torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
# Bt = torch.randn(dim2*4, dim2, device='cuda').half()
# torch.cuda.synchronize()
@@ -1644,9 +1633,9 @@ def test_coo2csr():
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
-torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
+torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
idx = A2 != 0
-torch.testing.assert_allclose(A2[idx], csrA.values)
+torch.testing.assert_close(A2[idx], csrA.values)
def test_coo2csc():
@@ -1664,10 +1653,10 @@ def test_coo2csc():
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
-torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
+torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = A2.t() != 0
-torch.testing.assert_allclose(A2.t()[idx], cscA.values)
+torch.testing.assert_close(A2.t()[idx], cscA.values)
n = 2
@@ -1717,7 +1706,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
max_count, max_idx = torch.sort(counts, descending=True)
print(torch.median(max_count.float()))
-torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
+torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
@@ -1787,38 +1776,43 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
batch_size = 1
seqdim = 1
values = []
-values.append((batch_size, seqdim, 768, 4 * 768))
-# values.append((batch_size, seqdim, 1024, 4*1024))
-# values.append((batch_size, seqdim, 1536, 4*1536))
-# values.append((batch_size, seqdim, 2048, 4*2048))
-# values.append((batch_size, seqdim, 2560, 4*2560))
-# values.append((batch_size, seqdim, 4096, 4*4096))
-# values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 768, 4 * 768))
+#values.append((batch_size, seqdim, 1024, 4*1024))
+#values.append((batch_size, seqdim, 1536, 4*1536))
+#values.append((batch_size, seqdim, 2048, 4*2048))
+#values.append((batch_size, seqdim, 2560, 4*2560))
+values.append((batch_size, seqdim, 4096, 4*4096))
+values.append((batch_size, seqdim, 5120, 4*5120))
+values.append((batch_size, seqdim, 6656, 4*6656))
+values.append((batch_size, seqdim, 8192, 4*8192))
+#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
-names = [
-"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
-]
+names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
-iters = 128
+iters = 80
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
-linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+B_fp4, state = F.quantize_fp4(B)
+B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
+B_nf4, state_nf4= F.quantize_nf4(B)
+linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
linear8bit.eval()
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
-linearMixedBit = (
-bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-)
-linearMixedBit.eval()
+linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
+#linearMixedBit.eval()
+linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
# warmup
for i in range(iters):
@@ -1831,61 +1825,80 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
-print(
-f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-)
+print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
-torch.cuda.synchronize()
-t0 = time.time()
-for i in range(iters):
-bnb.matmul(A, B)
-torch.cuda.synchronize()
-print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
-bnb.matmul(A, B, threshold=6.0)
+bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
torch.cuda.synchronize()
-print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
-CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
-C32A, SA = F.transform(CA, "col32")
-CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
-CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
-out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
torch.cuda.synchronize()
-print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
-BA, statsB = F.vectorwise_quant(B, dim=1)
-CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
-A2 = A.view(-1, A.shape[-1]).contiguous()
-CA, statsA = F.vectorwise_quant(A2, dim=1)
-C32A, SA = F.nvidia_transform(CA, "col32")
-out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize()
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# bnb.matmul(A, B)
#torch.cuda.synchronize()
#print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# bnb.matmul(A, B, threshold=6.0)
#torch.cuda.synchronize()
#print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
#C32A, SA = F.transform(CA, "col32")
#CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
#CxB, SB = F.transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
#torch.cuda.synchronize()
#print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1)
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
#torch.cuda.synchronize()
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-torch.cuda.synchronize()
-t0 = time.time()
-for i in range(iters):
-A2 = A.view(-1, A.shape[-1]).contiguous()
-CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
-C32A, SA = F.nvidia_transform(CA, "col32")
-out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-out = Cout * statsB * statsA * (1.0 / (127 * 127))
-torch.cuda.synchronize()
+#BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
+#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+#torch.cuda.synchronize()
+#t0 = time.time()
+#for i in range(iters):
+# A2 = A.view(-1, A.shape[-1]).contiguous()
+# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+# C32A, SA = F.nvidia_transform(CA, "col32")
+# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+# out = Cout * statsB * statsA * (1.0 / (127 * 127))
+#torch.cuda.synchronize()
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linear8bit(A)
@@ -1894,9 +1907,7 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(iters):
linear8bit(A)
torch.cuda.synchronize()
-print(
-f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-)
+print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linearMixedBit(A)
torch.cuda.synchronize()
@@ -1904,9 +1915,23 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(iters):
linearMixedBit(A)
torch.cuda.synchronize()
-print(
-f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-)
+print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linear8bit_train(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train_thresh(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linear8bit_train(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
def test_zeropoint():
def quant_zp(x):
@@ -2009,7 +2034,7 @@ def test_extract_outliers():
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
-torch.testing.assert_allclose(outliers1, outliers2)
+torch.testing.assert_close(outliers1, outliers2)
CA, SA = F.transform(A, "col_ampere")
@@ -2018,7 +2043,7 @@ def test_extract_outliers():
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
-torch.testing.assert_allclose(outliers1, outliers2)
+torch.testing.assert_close(outliers1, outliers2)
@@ -2050,7 +2075,6 @@ def test_fp8_quant():
p_bits = 7-e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
-print(e_bits, p_bits)
abserr = []
relerr = []
for i in range(100):
@@ -2149,7 +2173,7 @@ def test_few_bit_quant():
#assert err2.mean() <= err1
else:
-torch.testing.assert_allclose(q1, q2)
+torch.testing.assert_close(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
@@ -2181,7 +2205,9 @@ def test_kbit_quantile_estimation():
def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
-qa, SA = F.quantize_blockwise(a)
+code =F.create_fp8_map(True, 3, 0, 4).cuda()
+qa, SA = F.quantize_blockwise(a, code=code)
+print(qa.max())
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu)
@@ -2189,7 +2215,302 @@ def test_bench_dequantization():
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
-F.dequantize_blockwise(qa, SA, blocksize=2048)
+qa, SA = F.quantize_blockwise(a)
torch.cuda.synchronize()
#print((time.time()-t0)/1e6)
def test_fp4_quant():
vals = list(product([0, 1], repeat=4))
code = {}
for bits in vals:
result = 0
bias = 3
sign, e1, e2, p1 = bits
idx = sign*8 + e1*4 + e2*2 + p1*1
sign = -1.0 if sign else 1.0
exp = e1*2 + e2*1
if exp == 0:
# sub-normal
if p1 == 0: result = 0
else: result = sign*0.0625
else:
# normal
exp = 2**(-exp + bias + 1)
frac = 1.5 if p1 else 1.0
result = sign*exp*frac
code[idx] = result
A1 = torch.randn(1024, 1024, device='cuda').half()
qa, SA = F.quantize_fp4(A1, blocksize=64)
A2 = F.dequantize_fp4(qa, SA)
err = (A1 - A2).abs().float()
relerr = (err/A1.abs().float()).mean()
idx = err > 1.0
err = err.mean()
assert err.item() < 0.1
assert relerr.item() < 0.28
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
for blocksize in [128, 64]:
errs1 = []
errs2 = []
for i in range(10):
A1 = torch.randn(1024, 1024, device='cuda').half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean()
err = err.mean()
errs1.append(err.item())
assert err.item() < 0.11
assert relerr.item() < 0.28
err = (A1 - A3).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean()
err = err.mean()
errs2.append(err.item())
assert err.item() < 0.11
assert relerr.item() < 0.28
#print(sum(errs1)/len(errs1), blocksize, quant_type)
#print(sum(errs2)/len(errs2), blocksize, quant_type)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_bench_4bit_dequant(quant_type):
blocksize = 256
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
input_size = a.numel()/2
output_size = a.numel()*2
num_bytes = input_size+output_size
GB = num_bytes/1e9
max_theoretical_s = GB/768
#print(max_theoretical_s*1e6)
b = torch.randn(128, 1024*12, device='cuda').half()
iters = 5
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
#b.copy_(a)
torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# torch.matmul(b, a.t())
#torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6)
def test_normal_map_tree():
code = F.create_normal_map()
values =code[:8].tolist() + code[-8:].tolist()
num_pivots = 1
print(values)
while num_pivots <16:
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
print(idx)
num_pivots *= 2
pivots = []
for i in idx:
pivots.append((values[i-1]+values[i])/2)
print(pivots)
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_cutlass3_gemm(dtype):
debug = True
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [4096, 5120, 6656, 8192]:
for dim in [4096]:
#for dim in [128+1]:
errs = []
relerrs = []
max_err = 0
max_relerr = 0
for i in range(100):
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
#B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
#print('')
#print(A)
#print(B.t())
#A[:, :-1] = 0
#B[:, :-1] = 0
C1 = torch.matmul(A, B.t())
C2 = F.cutlass3_gemm(A, B.t())
# tensor cores are non-deterministic
# so we need to analyze errors around the mean
# to test our implementation
err = torch.abs(C1-C2)
mag = torch.abs(C1)+1e-8
relerr = err/mag
max_err = max(err.max(), max_err)
max_relerr = max(relerr.max(), max_relerr)
err = err.mean().item()
relerr = relerr.mean().item()
errs.append(err)
relerrs.append(relerr)
#if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
# print('')
# print(i, err, relerr)
# print(A.flatten()[-6:])
# print(B.flatten()[-6:])
# out = A.flatten()[-6:]*B.flatten()[-6:]
# print(out)
# print(out[:-1].sum())
# print('='*80)
# print(C1.flatten()[-6:])
# print(C2.flatten()[-6:])
# #assert False, 'ERROR'
c = int(C1.numel()*0.0014*(dim/256))+1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
#print(c/math.sqrt(dim))
print('')
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
print(dim, (max_err.item(), max_relerr.item()))
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_gemm_4bit(dtype):
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [4096, 5120, 6656, 8192]:
#for dim in [32]:
for dim in [4096]:
errs = []
relerrs = []
max_err = 0
max_relerr = 0
for i in range(1):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
#print('')
#print(A)
#print(B.t())
#A[:, :-1] = 0
#B[:, :-1] = 0
qB, state = F.quantize_nf4(B)
F.dequantize_nf4(qB, state)
C3 = torch.matmul(A, B.t())
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
C1 = bnb.matmul_4bit(A, qB.t(), state)
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
print(C1.shape, C2.shape)
# tensor cores are non-deterministic
# so we need to analyze errors around the mean
# to test our implementation
err = torch.abs(C1-C2)
mag = torch.abs(C1)+1e-8
relerr = err/mag
max_err = max(err.max(), max_err)
max_relerr = max(relerr.max(), max_relerr)
err = err.mean().item()
relerr = relerr.mean().item()
errs.append(err)
relerrs.append(relerr)
if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
print('')
print(i, err, relerr)
print(A.flatten()[-6:])
print(B.flatten()[-6:])
out = A.flatten()[-6:]*B.flatten()[-6:]
print(out)
print(out[:-1].sum())
print('='*80)
print(C1.flatten()[-6:])
print(C2.flatten()[-6:])
#assert False, 'ERROR'
c = int(C1.numel()*0.0014*(dim/256))+1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
#print(c/math.sqrt(dim))
print('')
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
print(dim, (max_err.item(), max_relerr.item()))
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
n = 32*10
A = F.get_paged(n, n, dtype=torch.float32)
B = F.get_paged(n, n, dtype=torch.uint8)
B2 = F.get_paged(n, n, dtype=torch.float32)
assert A.is_paged
assert B.is_paged
assert A.page_deviceid==0
assert B.page_deviceid==0
F.fill(A, 17.0)
F.fill(B, 17)
F.fill(B2, 2)
assert (A==17).sum().item() == n*n
assert (B==17).sum().item() == n*n
C = A*B.float()
assert (C==289).sum().item() == n*n
F._mul(A, B2)
F._mul(A, B2)
F._mul(A, B2)
assert (A==17*(2**3)).sum().item() == n*n
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.fill(B2, 17.0)
# F._mul(A, B2)
# F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True)
# torch.cuda.synchronize()
# assert (A==17).sum().item() == n*n
# torch.testing.assert_close(A, torch.ones(A.shape)*289)
@@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
sumval = (idx == 0).sum().item()
if sumval > count:
print(f"Too many values not close: assert {sumval} < {count}")
-torch.testing.assert_allclose(a, b, rtol, atol)
+torch.testing.assert_close(a, b, rtol, atol)
class LinearFunction(torch.autograd.Function):
@@ -330,18 +330,15 @@ def test_linear8bitlt_inference(threshold):
def test_linear8bitlt_accumulated_gradient():
-l1 = torch.nn.Sequential(
-*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-)
-l2 = torch.nn.Sequential(
-*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-)
-l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+l1[0].weight.data.copy_(l2[0].weight.data)
+l1[1].weight.data.copy_(l2[1].weight.data)
+l1[0].bias.data.copy_(l2[0].bias.data)
+l1[1].bias.data.copy_(l2[1].bias.data)
+opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
acc_steps = 10
@@ -371,26 +368,17 @@ def test_linear8bitlt_accumulated_gradient():
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
+l1[0].bias.data.copy_(l2[0].bias.data)
+l1[1].bias.data.copy_(l2[1].bias.data)
else:
-torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
-torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
+torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-l1 = (
-bnb.nn.Linear8bitLt(
-32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-)
-.cuda()
-.half()
-)
+l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
assert l1.weight.dtype == torch.int8
l1.eval()
@@ -446,13 +434,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
-mlp = (
-MLP8bit(
-32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-)
-.half()
-.to("cuda")
-)
+mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -499,15 +481,16 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
scale = grad_ref.abs().mean()
-torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
assert (idx == 0).sum().item() <= b1.numel() * 0.005
-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
# casts model to fp16 -> int8 automatically
-l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-assert l1.weight.dtype == torch.int8
+l1 = module(32, 64).cuda()
+assert l1.weight.dtype in [torch.int8, torch.uint8]
assert l1.bias.dtype == torch.float32
for i in range(100):
@@ -517,11 +500,116 @@ def test_linear8bitlt_fp32_bias():
assert l1.bias.dtype == torch.float16
# casts model to fp16 -> int8 automatically
-l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
+l1 = module(32, 64, bias=False).cuda()
assert l1.weight.dtype == torch.int8 assert l1.weight.dtype in [torch.int8, torch.uint8]
assert l1.bias is None assert l1.bias is None
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1) o1 = l1(b1)
assert l1.bias is None assert l1.bias is None
modules = []
modules.append(bnb.nn.Linear8bitLt)
modules.append(bnb.nn.Linear4bit)
modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
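# Each entry above is a constructor for a quantized drop-in Linear replacement; the
# '+C' variants pass compress_statistics=True, which additionally quantizes the
# per-block absmax statistics (double quantization) to save a bit more memory.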
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("module", modules, ids=names)
def test_kbit_backprop(module):
b = 17
dim1 = 37
dim2 = 83
ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
ref[1].weight.requires_grad = False
torch.nn.init.kaiming_normal_(ref[0].weight)
torch.nn.init.kaiming_normal_(ref[1].weight)
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
kbit[0].weight.detach().copy_(ref[0].weight)
kbit[1].weight.detach().copy_(ref[1].weight)
kbit[0].bias.detach().copy_(ref[0].bias)
kbit[1].bias.detach().copy_(ref[1].bias)
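# copy while the k-bit layer still holds full-precision parameters; for the 4-bit
# modules the .cuda() call below is what quantizes the weights in place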
ref = ref.half().cuda()
kbit = kbit.half().cuda()
errs1 = []
errs2 = []
relerrs1 = []
relerrs2 = []
for i in range(100):
batch = torch.randn(b, dim1).half().cuda()
out1 = ref(batch)
out2 = kbit(batch)
out1.mean().backward()
out2.mean().backward()
grad1 = ref[0].weight.grad
grad2 = kbit[0].weight.grad
bgrad1 = ref[0].bias.grad
bgrad2 = kbit[0].bias.grad
err1 = (out1-out2).abs().float()
err2 = (grad1-grad2).abs().float()
relerr1 = (err1/(out1.abs().float()+1e-9))
relerr2 = (err2/(grad1.abs().float()+1e-9))
errs1.append(err1.mean().item())
errs2.append(err2.mean().item())
relerrs1.append(relerr1.mean().item())
relerrs2.append(relerr2.mean().item())
if module == bnb.nn.Linear8bitLt:  # module is the class/factory passed by pytest, not an instance
torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
else:
torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
ref.zero_grad()
kbit.zero_grad()
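# zero_grad() leaves grads as None on newer torch (set_to_none defaults to True)
# and as zero tensors on older versions, so both outcomes are accepted below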
assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
print('out', sum(errs1)/len(errs1))
print('grad', sum(errs2)/len(errs2))
print('rel out', sum(relerrs1)/len(relerrs1))
print('rel grad', sum(relerrs2)/len(relerrs2))
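# Minimal usage sketch (assumes a CUDA device; the shapes are illustrative only):
# the k-bit modules exercised above are drop-in replacements for torch.nn.Linear
# whose weights get quantized when the module is moved to the GPU.
#
#   layer = bnb.nn.LinearNF4(1024, 4096)                      # fp32 weights on CPU
#   layer = layer.cuda()                                      # weights quantized to NF4 here
#   out = layer(torch.randn(8, 1024, device='cuda').half())   # fp16 activations in, fp16 out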
def test_fp8linear():
b = 10
h = 1024
inp = torch.randn(b, h).cuda()
fp32 = torch.nn.Linear(h, h*2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
fp32b = torch.nn.Linear(h*2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data)
fp8b.weight.data.copy_(fp32b.weight.data)
fp8b.bias.data.copy_(fp32b.bias.data)
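# LinearFP8Mixed is a research module that applies (simulated) FP8 quantization inside
# the linear layer; the assertions below bound its output and gradient error against
# the fp32 reference layers built from the same weights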
a = fp32b(torch.nn.functional.gelu(fp32(inp)))
b = fp8b(torch.nn.functional.gelu(fp8(inp)))
err = (a-b).abs().mean()
a.mean().backward()
b.mean().backward()
graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()
assert err < 0.05
assert graderr < 0.00002
assert bgraderr < 0.00002
...@@ -19,11 +19,11 @@ import bitsandbytes.functional as F ...@@ -19,11 +19,11 @@ import bitsandbytes.functional as F
k = 20 k = 20
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol, atol) idx = torch.isclose(a, b, rtol=rtol, atol=atol)
error_count = (idx == 0).sum().item() error_count = (idx == 0).sum().item()
if error_count > max_error_count: if error_count > max_error_count:
print(f"Too many values not close: assert {error_count} < {max_error_count}") print(f"Too many values not close: assert {error_count} < {max_error_count}")
torch.testing.assert_allclose(a, b, rtol, atol) torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def get_temp_dir(): def get_temp_dir():
...@@ -35,11 +35,8 @@ def get_temp_dir(): ...@@ -35,11 +35,8 @@ def get_temp_dir():
def rm_path(path): def rm_path(path):
shutil.rmtree(path) shutil.rmtree(path)
str2optimizers = {} str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = ( str2optimizers["momentum_pytorch"] = (
None, None,
...@@ -47,28 +44,20 @@ str2optimizers["momentum_pytorch"] = ( ...@@ -47,28 +44,20 @@ str2optimizers["momentum_pytorch"] = (
bnb.optim.Adam, bnb.optim.Adam,
) )
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion) str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["momentum"] = ( str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["lars"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
)
str2optimizers["rmsprop"] = ( str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["adam8bit"] = ( str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
torch.optim.Adam, str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["lion8bit"] = (
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = ( str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
...@@ -77,19 +66,12 @@ str2optimizers["rmsprop8bit"] = ( ...@@ -77,19 +66,12 @@ str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["lars8bit"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)
str2optimizers["adam8bit_blockwise"] = ( str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
torch.optim.Adam, str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["lion8bit_blockwise"] = ( str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = ( str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
...@@ -101,53 +83,35 @@ str2optimizers["rmsprop8bit_blockwise"] = ( ...@@ -101,53 +83,35 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames = {} str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")] str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [ str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
("exp_avg", "state1", "qmap1", "max1"), str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
("exp_avg_sq", "state2", "qmap2", "max2"), str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
] str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["lion8bit"] = [ str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
("exp_avg", "state1", "qmap1", "max1") str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
] str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["lamb8bit"] = [ str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["lion8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1")
]
str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
("square_avg", "state1", "qmap1", "absmax1") str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"] optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = [ names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name): def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
...@@ -159,6 +123,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): ...@@ -159,6 +123,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if gtype == torch.float32: if gtype == torch.float32:
atol, rtol = 1e-6, 1e-5 atol, rtol = 1e-6, 1e-5
elif gtype == torch.bfloat16:
atol, rtol = 1e-3, 1e-2
else: else:
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
...@@ -172,9 +138,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): ...@@ -172,9 +138,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
for name1, name2 in str2statenames[optim_name]: for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose( torch.testing.assert_close(
torch_optimizer.state[p1][name1], torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2], bnb_optimizer.state[p2][name2].cuda(),
atol=atol, atol=atol,
rtol=rtol, rtol=rtol,
) )
...@@ -201,14 +167,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): ...@@ -201,14 +167,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
atol=atol, rtol=rtol, atol=atol, rtol=rtol,
max_error_count=10) max_error_count=10)
if gtype == torch.float16: if gtype != torch.float32:
# the adam buffers should also be close because they are 32-bit # the adam buffers should also be close because they are 32-bit
# but the parameters can diverge because they are 16-bit # but the parameters can diverge because they are 16-bit
# the differences grow larger and larger with each update # the differences grow larger and larger with each update
# --> copy the state to keep weights close # --> copy the state to keep weights close
p1.data = p1.data.half().float() p1.data = p1.data.to(p2.dtype).float()
p2.copy_(p1.data) p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2) torch.testing.assert_close(p1.to(p2.dtype), p2)
if optim_name in ["lars", "lamb"]: if optim_name in ["lars", "lamb"]:
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
...@@ -268,7 +234,7 @@ def test_global_config(dim1, dim2, gtype): ...@@ -268,7 +234,7 @@ def test_global_config(dim1, dim2, gtype):
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [ optimizer_names = [
"adam8bit", "adam8bit",
"lion8bit", "lion8bit",
...@@ -276,7 +242,6 @@ optimizer_names = [ ...@@ -276,7 +242,6 @@ optimizer_names = [
"rmsprop8bit", "rmsprop8bit",
"adam8bit_blockwise", "adam8bit_blockwise",
"lion8bit_blockwise", "lion8bit_blockwise",
"lars8bit",
"momentum8bit_blockwise", "momentum8bit_blockwise",
"rmsprop8bit_blockwise", "rmsprop8bit_blockwise",
] ]
...@@ -288,6 +253,7 @@ names = [ ...@@ -288,6 +253,7 @@ names = [
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name): def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
...@@ -301,7 +267,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -301,7 +267,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if gtype == torch.float32: if gtype == torch.float32:
atol, rtol = 3e-3, 1e-3 atol, rtol = 3e-3, 1e-3
patol, prtol = 1e-5, 1e-3 patol, prtol = 1e-5, 1e-3
elif gtype == torch.bfloat16:
atol, rtol = 3e-3, 1e-3
patol, prtol = 1e-4, 1e-2
else: else:
atol, rtol = 3e-3, 1e-3 atol, rtol = 3e-3, 1e-3
patol, prtol = 1e-5, 1e-3 patol, prtol = 1e-5, 1e-3
...@@ -309,7 +277,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -309,7 +277,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
errors = [] errors = []
relerrors = [] relerrors = []
for i in range(50): for i in range(100):
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float() p1.grad = g.clone().float()
p2.grad = g.clone() p2.grad = g.clone()
...@@ -343,13 +311,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -343,13 +311,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
) )
== 0 == 0
) )
assert num_not_close.sum().item() < 20 #assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone()) dequant_states.append(s1.clone())
err = torch.abs(p1 - p2) err = torch.abs(p1 - p2)
relerr = err / (torch.abs(p1)+1e-9) relerr = err / (torch.abs(p1)+1e-9)
assert err.mean() < 0.0001 if g.dtype == torch.bfloat16:
assert relerr.mean() < 0.001 assert err.mean() < 0.00015
assert relerr.mean() < 0.0016
else:
assert err.mean() < 0.00012
assert relerr.mean() < 0.0012
errors.append(err.mean().item()) errors.append(err.mean().item())
relerrors.append(relerr.mean().item()) relerrors.append(relerr.mean().item())
...@@ -369,12 +341,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -369,12 +341,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path) rm_path(path)
torch.testing.assert_allclose( torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
raws1cpy, bnb_optimizer.state[p2][name2] torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
)
torch.testing.assert_allclose(
qmap1, bnb_optimizer.state[p2][qmap]
)
if "blockwise" in optim_name: if "blockwise" in optim_name:
s1 = F.dequantize_blockwise( s1 = F.dequantize_blockwise(
...@@ -389,17 +357,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -389,17 +357,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
absmax=bnb_optimizer.state[p2][max_val], absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2], A=bnb_optimizer.state[p2][name2],
) )
torch.testing.assert_allclose(s1cpy, s1) torch.testing.assert_close(s1cpy, s1)
num_not_close = ( num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
torch.isclose(
torch_optimizer.state[p1][name1],
s1,
atol=atol,
rtol=rtol,
)
== 0
)
assert num_not_close.sum().item() < 20 assert num_not_close.sum().item() < 20
# since Lion can have pretty noisy updates where things lie at the boundary # since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion # allow up to 5 errors for Lion
...@@ -409,10 +369,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -409,10 +369,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
# together so we can test against the Adam error # together so we can test against the Adam error
p1.data = p1.data.to(gtype).float() p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data) p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2) torch.testing.assert_close(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip( for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
str2statenames[optim_name], dequant_states
):
torch_optimizer.state[p1][name1].copy_(s.data) torch_optimizer.state[p1][name1].copy_(s.data)
# print(sum(errors)/len(errors)) # print(sum(errors)/len(errors))
...@@ -473,28 +431,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): ...@@ -473,28 +431,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32: if optim_bits == 32:
torch.testing.assert_allclose(p1, p2) torch.testing.assert_close(p1, p2)
torch.testing.assert_allclose( torch.testing.assert_close(
adam1.state[p1]["state1"], adam1.state[p1]["state1"],
adam2.state[p2]["state1"], adam2.state[p2]["state1"],
atol=5e-5, atol=5e-5,
rtol=1e-4, rtol=1e-4,
) )
torch.testing.assert_allclose( torch.testing.assert_close(
adam1.state[p1]["state2"], adam1.state[p1]["state2"],
adam2.state[p2]["state2"], adam2.state[p2]["state2"],
atol=5e-5, atol=5e-5,
rtol=1e-4, rtol=1e-4,
) )
elif optim_bits == 8: elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose( torch.testing.assert_close(
adam1.state[p1]["state1"], adam1.state[p1]["state1"],
adam2.state[p2]["state1"], adam2.state[p2]["state1"],
atol=2, atol=2,
rtol=1e-3, rtol=1e-3,
) )
torch.testing.assert_allclose( torch.testing.assert_close(
adam1.state[p1]["state2"], adam1.state[p1]["state2"],
adam2.state[p2]["state2"], adam2.state[p2]["state2"],
atol=2, atol=2,
...@@ -526,7 +484,7 @@ gtype = [torch.float32, torch.float16] ...@@ -526,7 +484,7 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit'] # optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"] optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = [ names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
...@@ -557,3 +515,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): ...@@ -557,3 +515,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
params = (k - k // 5) * dim1 * dim2 params = (k - k // 5) * dim1 * dim2
print(optim_name, gtype, s / params) print(optim_name, gtype, s / params)
# assert s < 3.9 # assert s < 3.9
dim1 = [2*1024]
gtype = [torch.float16]
#mode = ['torch', 'bnb']
mode = ['bnb']
optimizer_names = ['paged_adamw']
#optimizer_names = ['paged_adamw8bit_blockwise']
values = list(product(dim1,gtype, optimizer_names, mode))
names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
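# Rough timing benchmark rather than a correctness test: in 'bnb' mode a huge dummy
# allocation below creates GPU memory pressure so the paged optimizer state can be
# evicted to host RAM, and the timed steps then include any paging overhead.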
layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
layers1 = layers1.to(gtype)
layers1 = layers1.cuda()
large_tensor = None
if mode == 'torch':
optim = str2optimizers[optim_name][0](layers1.parameters())
else:
optim = str2optimizers[optim_name][1](layers1.parameters())
# ~18 GB of float32 (4.5e9 elements) to create GPU memory pressure so paged state gets evicted
large_tensor = torch.empty((int(4.5e9),), device='cuda')
torch.cuda.synchronize()
time.sleep(5)
num_batches = 5
batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()
for i in range(num_batches):
print(i)
b = batches[i]
if i ==2:
torch.cuda.synchronize()
t0 = time.time()
out1 = layers1(b)
loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
loss1.backward()
optim.step()
torch.cuda.synchronize()
print(mode, time.time() - t0)
import pytest
import torch
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt
@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
@pytest.mark.parametrize("vector_wise_quantization", [False, True])
def test_switchback(vector_wise_quantization):
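# SwitchBackLinear runs its matmuls through Triton int8 kernels; the
# vector_wise_quantization flag switches between per-vector and per-tensor scaling.
# The checks below only require its error against an fp16 nn.Linear to stay within
# 2x of the Linear8bitLt baseline's error.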
for dim in [83]:
for batch in [13]:
standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
switchback.weight.data.copy_(standard.weight)
switchback.bias.data.copy_(standard.bias)
baseline.weight.data.copy_(standard.weight)
baseline.bias.data.copy_(standard.bias)
x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
x2 = x1.clone().detach().requires_grad_(True)
x3 = x1.clone().detach().requires_grad_(True)
out_standard = standard(x1)
(2**10 * out_standard.abs().mean()).backward()
print(x2.dtype)
out_sb = switchback(x2)
(2**10 * out_sb.abs().mean()).backward()
out_baseline = baseline(x3)
(2**10 * out_baseline.abs().mean()).backward()
err_sb = (out_standard - out_sb).abs().mean()
err_baseline = (out_standard - out_baseline).abs().mean()
print('OUT', err_sb, err_baseline)
assert err_sb < 2 * err_baseline
err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
print('GW2', err_sb, err_baseline)
assert err_sb < 2 * err_baseline
err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
print('GW1', err_sb, err_baseline)
assert err_sb < 2 * err_baseline
err_sb = (x1.grad - x2.grad).abs().mean()
err_baseline = (x1.grad - x3.grad).abs().mean()
print('GX1', err_sb, err_baseline)
assert err_sb < 2 * err_baseline