"vscode:/vscode.git/clone" did not exist on "eb08835fa246ea81263eb25fbe2caa54ef11271c"
Commit a24aae30 authored by Jeongseok Kang

Merge branch 'main' into fix/libcuda-to-torch

parents 2b4cc256 4395d68c
@@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
 fi
-make clean
-export CUDA_HOME=$BASE_PATH/cuda-10.2
-make cuda10x_nomatmul CUDA_VERSION=102
-if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then
-    # Control will enter here if $DIRECTORY doesn't exist.
-    echo "Compilation unsuccessul!" 1>&2
-    exit 64
-fi
 make clean
 export CUDA_HOME=$BASE_PATH/cuda-11.0
 make cuda110_nomatmul CUDA_VERSION=110
...
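Each branch of this script compiles one shared library per CUDA toolkit and aborts when the expected binary is missing. A minimal post-build smoke test, assuming the repository-root layout and the _nocublaslt naming used above (this check is not part of the build script itself):

import ctypes

# Hypothetical check: CDLL resolves the shared object's symbols at load time,
# so a broken or mislinked build fails here immediately.
ctypes.CDLL("./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so")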
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MAX_NEW_TOKENS = 128
model_name = 'decapoda-research/llama-7b-hf'
text = 'Hamburg is in which country?\n'
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f'{free_in_GB-2}GB'
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map='auto',
load_in_8bit=True,
max_memory=max_memory
)
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
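The memory-budget logic above reads free memory only on the current device and applies the same string budget to every GPU. A hedged refactor of that logic into a helper (the function name and the 2 GB headroom are illustrative, not part of the library):

import torch

def max_memory_per_gpu(headroom_gb: int = 2) -> dict:
    # torch.cuda.mem_get_info() returns (free, total) bytes for the *current*
    # device; reusing one reading for all GPUs assumes a uniform setup.
    free_bytes, _ = torch.cuda.mem_get_info()
    budget = f"{int(free_bytes / 1024**3) - headroom_gb}GB"
    return {i: budget for i in range(torch.cuda.device_count())}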
@@ -18,10 +18,10 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.38.0",
+    version=f"0.39.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
-    description="8-bit optimizers and matrix multiplication routines.",
+    description="k-bit optimizers and matrix multiplication routines.",
     license="MIT",
     keywords="gpu optimizers optimization 8-bit quantization compression",
     url="https://github.com/TimDettmers/bitsandbytes",
...
@@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None
             if req_grad[0]:
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradA1, gradA2, atol=0.015, rtol=0.1
                 )
             if req_grad[1]:
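The hunks in this file migrate the suite from torch.testing.assert_allclose, which is deprecated and has since been removed from PyTorch, to torch.testing.assert_close. A minimal sketch of the replacement call (tensor values are illustrative); note that rtol and atol are keyword-only and must be passed together:

import torch

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.001, 2.001])
# Passes: |a - b| <= atol + rtol * |b| holds elementwise.
torch.testing.assert_close(a, b, atol=0.01, rtol=0.0)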
@@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradB1, gradB2, atol=0.18, rtol=0.3
                 )
@@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
             assert (idx == 0).sum().item() < n * 0.01
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 out_bnb, out_torch, atol=0.027, rtol=0.2
             )
@@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None
             if req_grad[0]:
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradA1, gradA2, atol=0.015, rtol=0.1
                 )
             if req_grad[1]:
@@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None
             if req_grad[0]:
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradA1, gradA2, atol=0.015, rtol=0.1
                 )
             if req_grad[1]:
@@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
 dim2.append(0)
 decomp = [0.0, 6.0]
-funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ["matmul"]
+funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
+str_funcs = ["matmullt", 'switchback_bnb']
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
 req_grad = list(product([True, False], repeat=3))
 req_grad_str = []
@@ -407,7 +407,7 @@ def test_matmullt(
         bias.grad = None
         if req_grad[0]:
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 gradA1, gradA2, atol=0.015, rtol=0.1
             )
         if req_grad[1]:
@@ -423,9 +423,204 @@ def test_matmullt(
             assert (idx == 0).sum().item() <= n * 0.1
             idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
             assert (idx == 0).sum().item() <= n * 0.02
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 gradB1, gradB2, atol=0.18, rtol=0.3
             )
         if req_grad[2]:
-            torch.testing.assert_allclose(gradBias1, gradBias2)
+            torch.testing.assert_close(gradBias1, gradBias2)

The remainder of this hunk adds the new test_matmul_4bit test:
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

funcs = [(torch.matmul, bnb.matmul_4bit)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        strval += 'T' if v else 'F'
    req_grad_str.append(strval)

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
compress_statistics = [False, True]
has_fp16_weights = [True, False]
has_bias = [True, False]
quant_type = ['fp4', 'nf4']
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    if not has_bias:
        req_grad = list(req_grad)
        req_grad[2] = False

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)

            B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2, quant_state, bias=bias2)

            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).float().mean().item()
            if n > 0:
                assert err < 0.115
                #assert err < 0.20

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias1 = bias.grad
                    bias.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias2 = bias.grad
                    bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[2]:
                    torch.testing.assert_close(gradBias1, gradBias2)
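test_matmul_4bit quantizes B once and passes the returned quant_state into every matmul_4bit call. A minimal round-trip sketch of that API (assumes a CUDA device; dequantize_4bit is used here only to estimate quantization error and is assumed available alongside quantize_4bit, as in current releases):

import torch
import bitsandbytes as bnb

W = torch.randn(64, 64, device="cuda", dtype=torch.float16)
# quant_type is 'fp4' or 'nf4', matching the parametrization above.
W4, quant_state = bnb.functional.quantize_4bit(W, quant_type="nf4")
W_hat = bnb.functional.dequantize_4bit(W4, quant_state)
print((W - W_hat).abs().mean().item())  # mean absolute quantization error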
funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        strval += 'T' if v else 'F'
    req_grad_str.append(strval)

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    req_grad = list(req_grad)
    req_grad[2] = False

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
            torch.nn.init.xavier_uniform_(B)

            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B, fw_code, bw_code)

            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).float().mean().item()
            if n > 0:
                assert err < 0.115
                #assert err < 0.20

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[1]:
                    n = gradB1.numel()
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.02
                    grad_err = (gradB1 - gradB2).abs().mean()
                    assert grad_err.item() < 0.003
                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
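The FP8 test builds two different quantization maps: an E4M3 code (4 exponent bits, 3 mantissa bits) for the forward pass, and an E5M2 code for the backward pass, since gradients need dynamic range more than precision. A standalone sketch of those two calls, reusing the exact arguments from the test above:

import bitsandbytes as bnb

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8)  # signed E4M3, 8 bits total
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8)  # signed E5M2, 8 bits total
print(fw_code.numel(), bw_code.numel())  # lookup tables with 2**8 = 256 entries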
@@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
     sumval = (idx == 0).sum().item()
     if sumval > count:
         print(f"Too many values not close: assert {sumval} < {count}")
-        torch.testing.assert_allclose(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

 class LinearFunction(torch.autograd.Function):
@@ -330,18 +330,15 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(
-        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-    )
-    l2 = torch.nn.Sequential(
-        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-    )
-    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+    l1[0].weight.data.copy_(l2[0].weight.data)
+    l1[1].weight.data.copy_(l2[1].weight.data)
+    l1[0].bias.data.copy_(l2[0].bias.data)
+    l1[1].bias.data.copy_(l2[1].bias.data)
+
+    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

     acc_steps = 10
@@ -371,26 +368,17 @@ def test_linear8bitlt_accumulated_gradient():
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
+            l1[0].bias.data.copy_(l2[0].bias.data)
+            l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
-            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
+            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = (
-        bnb.nn.Linear8bitLt(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .cuda()
-        .half()
-    )
+    l1 = (bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8

     l1.eval()
@@ -446,13 +434,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8

-    mlp = (
-        MLP8bit(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .half()
-        .to("cuda")
-    )
+    mlp = (MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -499,15 +481,16 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
         scale = grad_ref.abs().mean()

-        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
         idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
         assert (idx == 0).sum().item() <= b1.numel() * 0.005
-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32

     for i in range(100):
@@ -517,11 +500,116 @@ def test_linear8bitlt_fp32_bias():
         assert l1.bias.dtype == torch.float16

     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64, bias=False).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         assert l1.bias is None
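A standalone sketch of the behavior this test pins down, assuming a CUDA device: the bias stays fp32 through .cuda(), then is cast to match the fp16 input during the first forward pass:

import torch
import bitsandbytes as bnb

l = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
print(l.bias.dtype)  # torch.float32
x = torch.randn(4, 32, device="cuda").half()
_ = l(x)
print(l.bias.dtype)  # torch.float16 after the first half-precision forward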
modules = []
modules.append(bnb.nn.Linear8bitLt)
modules.append(bnb.nn.Linear4bit)
modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("module", modules, ids=names)
def test_kbit_backprop(module):
    b = 17
    dim1 = 37
    dim2 = 83

    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
    ref[1].weight.requires_grad = False
    torch.nn.init.kaiming_normal_(ref[0].weight)
    torch.nn.init.kaiming_normal_(ref[1].weight)
    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
    kbit[0].weight.detach().copy_(ref[0].weight)
    kbit[1].weight.detach().copy_(ref[1].weight)
    kbit[0].bias.detach().copy_(ref[0].bias)
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()

    errs1 = []
    errs2 = []
    relerrs1 = []
    relerrs2 = []
    for i in range(100):
        batch = torch.randn(b, dim1).half().cuda()
        out1 = ref(batch)
        out2 = kbit(batch)
        out1.mean().backward()
        out2.mean().backward()

        grad1 = ref[0].weight.grad
        grad2 = kbit[0].weight.grad
        bgrad1 = ref[0].bias.grad
        bgrad2 = kbit[0].bias.grad

        err1 = (out1 - out2).abs().float()
        err2 = (grad1 - grad2).abs().float()
        relerr1 = (err1 / (out1.abs().float() + 1e-9))
        relerr2 = (err2 / (grad1.abs().float() + 1e-9))
        errs1.append(err1.mean().item())
        errs2.append(err2.mean().item())
        relerrs1.append(relerr1.mean().item())
        relerrs2.append(relerr2.mean().item())

        # `module` is a class or factory, not an instance, so compare by
        # identity; Int8Lt gets the tighter tolerances.
        if module is bnb.nn.Linear8bitLt:
            torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()

        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
    print('out', sum(errs1) / len(errs1))
    print('grad', sum(errs2) / len(errs2))
    print('rel out', sum(relerrs1) / len(relerrs1))
    print('rel grad', sum(relerrs2) / len(relerrs2))
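For reference, the k-bit modules under test are drop-in replacements for torch.nn.Linear. A minimal inference sketch, assuming a CUDA device (as the earlier dtype assertions show, the weight is stored quantized as int8/uint8 once moved to the GPU):

import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearNF4(64, 32).half().cuda()  # weight is quantized on the move to CUDA
x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
y = layer(x)
print(y.dtype, layer.weight.dtype)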
def test_fp8linear():
    b = 10
    h = 1024
    inp = torch.randn(b, h).cuda()
    fp32 = torch.nn.Linear(h, h * 2).cuda()
    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
    fp32b = torch.nn.Linear(h * 2, h).cuda()
    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()

    fp8.weight.data.copy_(fp32.weight.data)
    fp8.bias.data.copy_(fp32.bias.data)
    fp8b.weight.data.copy_(fp32b.weight.data)
    fp8b.bias.data.copy_(fp32b.bias.data)

    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
    b = fp8b(torch.nn.functional.gelu(fp8(inp)))

    err = (a - b).abs().mean()

    a.mean().backward()
    b.mean().backward()

    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()

    assert err < 0.05
    assert graderr < 0.00002
    assert bgraderr < 0.00002
import pytest
import torch

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt


@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
@pytest.mark.parametrize("vector_wise_quantization", [False, True])
def test_switchback(vector_wise_quantization):
    for dim in [83]:
        for batch in [13]:
            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
            baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
            switchback.weight.data.copy_(standard.weight)
            switchback.bias.data.copy_(standard.bias)
            baseline.weight.data.copy_(standard.weight)
            baseline.bias.data.copy_(standard.bias)

            x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
            x2 = x1.clone().detach().requires_grad_(True)
            x3 = x1.clone().detach().requires_grad_(True)

            out_standard = standard(x1)
            (2**10 * out_standard.abs().mean()).backward()

            print(x2.dtype)
            out_sb = switchback(x2)
            (2**10 * out_sb.abs().mean()).backward()

            out_baseline = baseline(x3)
            (2**10 * out_baseline.abs().mean()).backward()

            err_sb = (out_standard - out_sb).abs().mean()
            err_baseline = (out_standard - out_baseline).abs().mean()
            print('OUT', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
            err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
            print('GW2', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
            err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
            print('GW1', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (x1.grad - x2.grad).abs().mean()
            err_baseline = (x1.grad - x3.grad).abs().mean()
            print('GX1', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline
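SwitchBackLinear is itself a drop-in replacement for torch.nn.Linear; a minimal construction sketch mirroring the test setup (requires triton and a compute-capability 8.0+ GPU, per the skip condition above; the dimensions are illustrative):

import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

layer = SwitchBackLinear(128, 512, vector_wise_quantization=True).cuda().half()
x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
y = layer(x)  # int8 switchback matmul under the hood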