Commit 675baa79 authored by Tim Dettmers

Merge remote-tracking branch 'origin/main' into merge

parents f64cfe65 9e7cdc9e
@@ -12,10 +12,12 @@ URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installer
 URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
 URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
 URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run
+URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run

 CUDA_VERSION=$1
 BASE_PATH=$2
+EXPORT_BASHRC=$3

 if [[ -n "$CUDA_VERSION" ]]; then
   if [[ "$CUDA_VERSION" -eq "92" ]]; then
@@ -60,11 +62,14 @@ if [[ -n "$CUDA_VERSION" ]]; then
   elif [[ "$CUDA_VERSION" -eq "120" ]]; then
     URL=$URL120
     FOLDER=cuda-12.0
+  elif [[ "$CUDA_VERSION" -eq "121" ]]; then
+    URL=$URL121
+    FOLDER=cuda-12.1
   else
-    echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
+    echo "argument error: No cuda version passed as input. Choose among versions 92 to 121"
   fi
 else
-  echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
+  echo "argument error: No cuda version passed as input. Choose among versions 92 to 112"
 fi

 FILE=$(basename $URL)
@@ -72,11 +77,13 @@ FILE=$(basename $URL)
 if [[ -n "$CUDA_VERSION" ]]; then
   echo $URL
   echo $FILE
-  wget $URL
+  #wget $URL
   bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
-  echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
-  echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
-  source ~/.bashrc
+  if [ "$EXPORT_BASHRC" -eq "1" ]; then
+    echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc
+    echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc
+    source ~/.bashrc
+  fi
 else
   echo ""
 fi
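After this change the installer script takes three positional arguments: the CUDA version key, the base install path, and a flag that controls whether the ``export`` lines are appended to ``~/.bashrc``. A typical invocation would therefore look like ``bash install_cuda.sh 121 $HOME/local 1`` (the script filename is an assumption here, since the file header is not visible in this diff).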
@@ -10,8 +10,8 @@ if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then
 fi

-module unload cuda
-module unload gcc
+module unload cuda && echo "no module function available. Probably not on a slurm cluster."
+module unload gcc && echo "no module function available. Probably not on a slurm cluster."

 rm -rf dist build
 make cleaneggs
@@ -128,6 +128,16 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then
     exit 64
 fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-12.1
make cuda12x CUDA_VERSION=121
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
 make clean
 export CUDA_HOME=$BASE_PATH/cuda-10.2
@@ -241,5 +251,15 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then
     exit 64
 fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-12.1
make cuda12x_nomatmul CUDA_VERSION=121
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
python -m build
python -m twine upload dist/* --verbose

# No kernel image available
This problem arises when the CUDA version loaded by bitsandbytes is not supported by your GPU, or if your PyTorch CUDA version does not match. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, and ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of the files in these paths. Do they match the version shown by ``nvidia-smi``?
If you are feeling lucky, you can also try to compile the library from source. This can still be problematic if your PATH variables contain multiple CUDA versions. As such, it is recommended to figure out path conflicts before you proceed with compilation.
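A minimal sketch of the checks described above, using only PyTorch (``torch.version.cuda`` is the CUDA version PyTorch was built against; compare it with what ``nvidia-smi`` and your library paths report):

```python
import os
import torch

# CUDA version PyTorch was compiled against vs. runtime availability.
print("torch CUDA version:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())

# Environment variables worth inspecting for conflicting CUDA installations.
for var in ("LD_LIBRARY_PATH", "CUDA_HOME", "PATH"):
    print(f"{var} = {os.environ.get(var, '<unset>')}")
```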
...
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MAX_NEW_TOKENS = 128
model_name = 'decapoda-research/llama-7b-hf'
text = 'Hamburg is in which country?\n'
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f'{free_in_GB-2}GB'  # leave ~2 GB of headroom per GPU
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map='auto',
load_in_8bit=True,
max_memory=max_memory
)
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
@@ -18,7 +18,7 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.37.0",
+    version=f"0.38.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
...
...@@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None B.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_allclose( torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1 gradA1, gradA2, atol=0.015, rtol=0.1
) )
if req_grad[1]: if req_grad[1]:
...@@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
assert (idx == 0).sum().item() < n * 0.1 assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02 assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_allclose( torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3 gradB1, gradB2, atol=0.18, rtol=0.3
) )
...@@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel() n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01 assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_allclose( torch.testing.assert_close(
out_bnb, out_torch, atol=0.027, rtol=0.2 out_bnb, out_torch, atol=0.027, rtol=0.2
) )
...@@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None B.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_allclose( torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1 gradA1, gradA2, atol=0.015, rtol=0.1
) )
if req_grad[1]: if req_grad[1]:
...@@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): ...@@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None B.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_allclose( torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1 gradA1, gradA2, atol=0.015, rtol=0.1
) )
if req_grad[1]: if req_grad[1]:
...@@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist() ...@@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0) dim2.append(0)
decomp = [0.0, 6.0] decomp = [0.0, 6.0]
-funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ["matmul"]
+funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
+str_funcs = ["matmullt", 'switchback_bnb']
req_grad = [(False, False), (True, False), (True, True), (False, True)] req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad = list(product([True, False], repeat=3)) req_grad = list(product([True, False], repeat=3))
req_grad_str = [] req_grad_str = []
...@@ -407,7 +407,7 @@ def test_matmullt( ...@@ -407,7 +407,7 @@ def test_matmullt(
bias.grad = None bias.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_allclose( torch.testing.assert_close(
gradA1, gradA2, atol=0.015, rtol=0.1 gradA1, gradA2, atol=0.015, rtol=0.1
) )
if req_grad[1]: if req_grad[1]:
...@@ -423,12 +423,12 @@ def test_matmullt( ...@@ -423,12 +423,12 @@ def test_matmullt(
assert (idx == 0).sum().item() <= n * 0.1 assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02 assert (idx == 0).sum().item() <= n * 0.02
torch.testing.assert_allclose( torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3 gradB1, gradB2, atol=0.18, rtol=0.3
) )
if req_grad[2]: if req_grad[2]:
torch.testing.assert_allclose(gradBias1, gradBias2) torch.testing.assert_close(gradBias1, gradBias2)
n = 1 n = 1
...@@ -502,6 +502,7 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, ...@@ -502,6 +502,7 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
if n > 0: if n > 0:
assert err < 0.115 assert err < 0.115
#assert err < 0.20
if any(req_grad): if any(req_grad):
out_bnb.data.copy_(out_torch) out_bnb.data.copy_(out_torch)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -526,7 +527,100 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, ...@@ -526,7 +527,100 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias.grad = None bias.grad = None
if req_grad[0]: if req_grad[0]:
torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1) torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[2]: if req_grad[2]:
torch.testing.assert_allclose(gradBias1, gradBias2) torch.testing.assert_close(gradBias1, gradBias2)
funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
strval = ''
for v in c:
if v == True: strval += 'T'
else: strval += 'F'
req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
req_grad = list(req_grad)
req_grad[2] = False
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
torch.nn.init.xavier_uniform_(B)
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
elif not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B, fw_code, bw_code)
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
n = out_bnb.numel()
err = torch.abs(out_bnb - out_torch).float().mean().item()
if n > 0:
assert err < 0.115
#assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02
grad_err = (gradB1-gradB2).abs().mean()
assert grad_err.item() < 0.003
torch.testing.assert_close(
gradB1, gradB2, atol=0.18, rtol=0.3
)
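For orientation, this is roughly how the FP8 matmul path exercised by the new test is driven; the sketch below mirrors the calls in the test (E4M3 forward map, E5M2 backward map from ``bnb.functional.create_fp8_map``) and should be read as an illustration of the test, not as documented API:

```python
import torch
import bitsandbytes as bnb

A = torch.randn(32, 64, device="cuda", dtype=torch.float16, requires_grad=True)
B = torch.randn(48, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Quantization maps, constructed exactly as in test_matmul_fp8 above.
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

out = bnb.research.matmul_fp8_mixed(A, B.t(), fw_code, bw_code)
out.mean().backward()  # gradients flow back to A and B, as the test checks
```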
@@ -5,95 +5,20 @@ import pytest
 import bitsandbytes as bnb
 from bitsandbytes.cuda_setup.main import (
-    CUDA_RUNTIME_LIB,
     determine_cuda_runtime_lib_path,
     evaluate_cuda_setup,
     extract_candidate_paths,
 )
"""
'LD_LIBRARY_PATH': ':/mnt/D/titus/local/cuda-11.1/lib64/'
'CONDA_EXE': '/mnt/D/titus/miniconda/bin/conda'
'LESSCLOSE': '/usr/bin/lesspipe %s %s'
'OLDPWD': '/mnt/D/titus/src'
'CONDA_PREFIX': '/mnt/D/titus/miniconda/envs/8-bit'
'SSH_AUTH_SOCK': '/mnt/D/titus/.ssh/ssh-agent.tim-uw.sock'
'CONDA_PREFIX_1': '/mnt/D/titus/miniconda'
'PWD': '/mnt/D/titus/src/8-bit'
'HOME': '/mnt/D/titus'
'CONDA_PYTHON_EXE': '/mnt/D/titus/miniconda/bin/python'
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
'TMUX': '/tmp/tmux-1007/default,59286,1'
'XDG_DATA_DIRS': '/usr/local/share:/usr/share:/var/lib/snapd/desktop'
'SSH_TTY': '/dev/pts/0'
'MAIL': '/var/mail/titus'
'SHELL': '/bin/bash'
'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1007/bus'
'XDG_RUNTIME_DIR': '/run/user/1007'
'PATH': '/mnt/D/titus/miniconda/envs/8-bit/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/mnt/D/titus/local/cuda-11.1/bin'
'LESSOPEN': '| /usr/bin/lesspipe %s'
'_': '/mnt/D/titus/miniconda/envs/8-bit/bin/python'
# any that include 'CONDA' that are not 'CONDA_PREFIX'
# we search for
def test_cuda_full_system():
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
"""
class InputAndExpectedOutput(NamedTuple):
input: str
output: str
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
(
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
]
@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
for path in extract_candidate_paths(request.param):
test_dir.mkdir()
if CUDA_RUNTIME_LIB in path:
(test_input / CUDA_RUNTIME_LIB).touch()
UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
]
def test_full_system():
## this only tests the cuda version and not compute capability ## this only tests the cuda version and not compute capability
# if CONDA_PREFIX exists, it has priority before all other env variables # if CONDA_PREFIX exists, it has priority before all other env variables
# but it does not contain the library directly, so we need to look at the a sub-folder # but it does not contain the library directly, so we need to look at the a sub-folder
version = "" version = ""
if "CONDA_PREFIX" in os.environ: if "CONDA_PREFIX" in os.environ:
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so') ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0')
major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")) major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split("."))
version = float(f"{major}.{minor}") version = float(f"{major}.{minor}")
...
...@@ -24,7 +24,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): ...@@ -24,7 +24,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
if sumval > count: if sumval > count:
if throw: if throw:
print(f"Too many values not close: assert {sumval} < {count}") print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol) torch.testing.assert_close(a, b, rtol, atol)
return sumval return sumval
...@@ -100,7 +100,7 @@ def test_estimate_quantiles(dtype): ...@@ -100,7 +100,7 @@ def test_estimate_quantiles(dtype):
code = F.estimate_quantiles(A) code = F.estimate_quantiles(A)
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device="cuda") A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype) A = A.to(dtype)
...@@ -125,7 +125,7 @@ def test_quantile_quantization(): ...@@ -125,7 +125,7 @@ def test_quantile_quantization():
C = F.quantize_no_absmax(A1, code) C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code) A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item() diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001 assert diff < 0.001
...@@ -149,7 +149,7 @@ def test_dynamic_quantization(): ...@@ -149,7 +149,7 @@ def test_dynamic_quantization():
C, S = F.quantize(A1) C, S = F.quantize(A1)
A2 = F.dequantize(C, S) A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2).mean().item() diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004 assert diff < 0.004
...@@ -184,7 +184,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize): ...@@ -184,7 +184,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
reldiff = diff / torch.abs(A1 + 1e-8) reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item()) diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item()) reldiffs.append(reldiff.mean().item())
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs) abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs) relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.0035 assert abserr < 0.0035
...@@ -193,22 +193,6 @@ def test_dynamic_blockwise_quantization(nested, blocksize): ...@@ -193,22 +193,6 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
def test_dynamic_blockwise_stochastic_quantization():
diffs = []
reldiffs = []
rand = torch.rand(1024).cuda()
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C1, S1 = F.quantize_blockwise(A1, rand=rand)
C2, S2 = F.quantize_blockwise(A1)
# a maximunm distance of quantized values of 1
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
fraction_smaller = (C1 < C2).float().sum() / C1.numel()
fraction_larger = (C1 > C2).float().sum() / C1.numel()
torch.testing.assert_allclose(
fraction_larger, fraction_smaller, atol=0.01, rtol=0
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gtype", [torch.float32, torch.float16], ids=["float", "half"] "gtype", [torch.float32, torch.float16], ids=["float", "half"]
...@@ -236,9 +220,9 @@ def test_percentile_clipping(gtype): ...@@ -236,9 +220,9 @@ def test_percentile_clipping(gtype):
vals, idx = torch.sort(gnorm_vec1) vals, idx = torch.sort(gnorm_vec1)
clip1 = vals[percentile] clip1 = vals[percentile]
torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2)) torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
torch.testing.assert_allclose(clip1, clip2) torch.testing.assert_close(clip1, clip2)
torch.testing.assert_allclose(gnorm1, gnorm2) torch.testing.assert_close(gnorm1, gnorm2)
def quant(x): def quant(x):
...@@ -332,7 +316,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): ...@@ -332,7 +316,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1) maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0) maxB, Bc = quant_methods[1](B, 0)
torch.testing.assert_allclose( torch.testing.assert_close(
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
) )
if batched: if batched:
...@@ -403,7 +387,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): ...@@ -403,7 +387,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
out2 = torch.matmul(A.t().float(), B.t().float()) out2 = torch.matmul(A.t().float(), B.t().float())
out = F.igemm(A.t(), B.t()) out = F.igemm(A.t(), B.t())
torch.testing.assert_allclose(out.float(), out2) torch.testing.assert_close(out.float(), out2)
for i in range(k): for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim) shapeA = (batch_dim, seq_dim, hidden_dim)
...@@ -421,7 +405,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): ...@@ -421,7 +405,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
out2 = torch.matmul(A.float(), B.t().float()) out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t()) out = F.igemm(A, B.t())
torch.testing.assert_allclose(out.float(), out2) torch.testing.assert_close(out.float(), out2)
n = 3 n = 3
...@@ -452,7 +436,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): ...@@ -452,7 +436,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
) )
out = F.igemm(A, B, out=iout) out = F.igemm(A, B, out=iout)
torch.testing.assert_allclose(out.float(), out2) torch.testing.assert_close(out.float(), out2)
n = 2 n = 2
...@@ -577,7 +561,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): ...@@ -577,7 +561,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
) )
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_allclose(out.float(), out2.float()) torch.testing.assert_close(out.float(), out2.float())
n = 1 n = 1
...@@ -635,9 +619,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -635,9 +619,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
out, S = F.nvidia_transform(A, to_order=orderOut) out, S = F.nvidia_transform(A, to_order=orderOut)
if orderOut == "row": if orderOut == "row":
torch.testing.assert_allclose(A.flatten(), out.flatten()) torch.testing.assert_close(A.flatten(), out.flatten())
elif orderOut == "col": elif orderOut == "col":
torch.testing.assert_allclose(A.t().flatten(), out.flatten()) torch.testing.assert_close(A.t().flatten(), out.flatten())
elif orderOut == "col32": elif orderOut == "col32":
if dims == 2: if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
...@@ -670,14 +654,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -670,14 +654,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
assert A.flatten()[i + j] == A[row, col] assert A.flatten()[i + j] == A[row, col]
# assert A.flatten()[i+j] == out.flatten()[row2+col2] # assert A.flatten()[i+j] == out.flatten()[row2+col2]
# torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32": if orderOut == "col32":
out2, S = F.nvidia_transform( out2, S = F.nvidia_transform(
out, from_order=orderOut, to_order="row", state=S out, from_order=orderOut, to_order="row", state=S
) )
torch.testing.assert_allclose(A, out2) torch.testing.assert_close(A, out2)
n = 1 n = 1
...@@ -721,7 +705,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): ...@@ -721,7 +705,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
B2, SB = F.transform(B, "col_turing") B2, SB = F.transform(B, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB) C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC) C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float()) torch.testing.assert_close(C1, C3.float())
# transpose # transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
...@@ -732,7 +716,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): ...@@ -732,7 +716,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
B2t, SBt = F.transform(B, "col_turing", transpose=True) B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt) C2, SC = F.igemmlt(A2, B2t, SA, SBt)
C3, S = F.nvidia_transform(C2, "row", state=SC) C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float()) torch.testing.assert_close(C1, C3.float())
dim1 = [32] dim1 = [32]
...@@ -778,7 +762,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): ...@@ -778,7 +762,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# print(C1.flatten()[:10]) # print(C1.flatten()[:10])
# print(C2.flatten()[:10]) # print(C2.flatten()[:10])
# torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# transpose # transpose
# B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
...@@ -787,7 +771,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): ...@@ -787,7 +771,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
# B2t, SBt = F.transform2(B, 'col_turing', transpose=True) # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
# C2, SC = F.igemmlt(A2, B2t, SA, SBt) # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
# C3, S = F.transform(C2, 'row', state=SC) # C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_allclose(C1, C3.float()) # torch.testing.assert_close(C1, C3.float())
batch_size = 2 batch_size = 2
...@@ -1006,7 +990,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): ...@@ -1006,7 +990,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
#assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
#torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1) #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
n = C5.numel() n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n)) assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))
...@@ -1056,16 +1040,16 @@ def test_colrow_absmax(dim1, dim2, dims): ...@@ -1056,16 +1040,16 @@ def test_colrow_absmax(dim1, dim2, dims):
) )
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_allclose(col_stats1_trunc, col_stats2) torch.testing.assert_close(col_stats1_trunc, col_stats2)
torch.testing.assert_allclose(row_stats1_trunc, row_stats2) torch.testing.assert_close(row_stats1_trunc, row_stats2)
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2) torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=0.0 A, threshold=0.0
) )
torch.testing.assert_allclose(col_stats1, col_stats2) torch.testing.assert_close(col_stats1, col_stats2)
torch.testing.assert_allclose(row_stats1, row_stats2) torch.testing.assert_close(row_stats1, row_stats2)
assert nnz_block_ptr2 is None assert nnz_block_ptr2 is None
...@@ -1089,8 +1073,8 @@ def test_double_quant(dim1, dim2): ...@@ -1089,8 +1073,8 @@ def test_double_quant(dim1, dim2):
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# max difference is 1 due to rounding differences # max difference is 1 due to rounding differences
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0) torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel() n = CAt.numel()
num_not_close_rows = ( num_not_close_rows = (
...@@ -1113,8 +1097,8 @@ def test_double_quant(dim1, dim2): ...@@ -1113,8 +1097,8 @@ def test_double_quant(dim1, dim2):
) )
assert False assert False
torch.testing.assert_allclose(Srow.flatten(), statsA) torch.testing.assert_close(Srow.flatten().float(), statsA)
torch.testing.assert_allclose(Scol.flatten(), statsAt) torch.testing.assert_close(Scol.flatten().float(), statsAt)
n = 4 n = 4
...@@ -1139,10 +1123,10 @@ def test_integrated_igemmlt(dim1, dim4, inner): ...@@ -1139,10 +1123,10 @@ def test_integrated_igemmlt(dim1, dim4, inner):
A1, maxA = F.vectorwise_quant(A, dim=1) A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1)
torch.testing.assert_allclose(maxA.flatten(), stats1a) torch.testing.assert_close(maxA.flatten().float(), stats1a)
torch.testing.assert_allclose(maxB.flatten(), stats2a) torch.testing.assert_close(maxB.flatten().float(), stats2a)
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1) torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1) torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
A2, SA = F.nvidia_transform(C1a, "col32") A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(C2a, "col_turing") B2, SB = F.nvidia_transform(C2a, "col_turing")
...@@ -1344,7 +1328,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): ...@@ -1344,7 +1328,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
# print(out1) # print(out1)
# print(out2) # print(out2)
torch.testing.assert_allclose(out1, out2) torch.testing.assert_close(out1, out2)
n = 2 n = 2
...@@ -1406,11 +1390,11 @@ def test_coo_double_quant(dim1, dim2): ...@@ -1406,11 +1390,11 @@ def test_coo_double_quant(dim1, dim2):
A2[ A2[
coo_tensor.rowidx.long(), coo_tensor.colidx.long() coo_tensor.rowidx.long(), coo_tensor.colidx.long()
] = coo_tensor.values ] = coo_tensor.values
torch.testing.assert_allclose(A1, A2) torch.testing.assert_close(A1, A2)
A1 = A * (idx == 0) A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_allclose( torch.testing.assert_close(
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
) )
...@@ -1618,7 +1602,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): ...@@ -1618,7 +1602,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
idx_col = torch.randint(0, A2.shape[-1], size=(15,)) idx_col = torch.randint(0, A2.shape[-1], size=(15,))
# torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
# Bt = torch.randn(dim2*4, dim2, device='cuda').half() # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
# torch.cuda.synchronize() # torch.cuda.synchronize()
...@@ -1649,9 +1633,9 @@ def test_coo2csr(): ...@@ -1649,9 +1633,9 @@ def test_coo2csr():
counts = csrA.rowptr[1:] - csrA.rowptr[:-1] counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0] assert counts.numel() == A.shape[0]
torch.testing.assert_allclose(counts, (A2 != 0).sum(1)) torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
idx = A2 != 0 idx = A2 != 0
torch.testing.assert_allclose(A2[idx], csrA.values) torch.testing.assert_close(A2[idx], csrA.values)
def test_coo2csc(): def test_coo2csc():
...@@ -1669,10 +1653,10 @@ def test_coo2csc(): ...@@ -1669,10 +1653,10 @@ def test_coo2csc():
counts = cscA.colptr[1:] - cscA.colptr[:-1] counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1] assert counts.numel() == A.shape[1]
torch.testing.assert_allclose(counts, (A2 != 0).sum(0)) torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major # torch uses row-major -> use transpose to transfer to col-major
idx = A2.t() != 0 idx = A2.t() != 0
torch.testing.assert_allclose(A2.t()[idx], cscA.values) torch.testing.assert_close(A2.t()[idx], cscA.values)
n = 2 n = 2
...@@ -1722,7 +1706,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): ...@@ -1722,7 +1706,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
max_count, max_idx = torch.sort(counts, descending=True) max_count, max_idx = torch.sort(counts, descending=True)
print(torch.median(max_count.float())) print(torch.median(max_count.float()))
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
p = 200 / (2048 * 12288 * 4) p = 200 / (2048 * 12288 * 4)
n = out1.numel() n = out1.numel()
...@@ -1793,13 +1777,13 @@ batch_size = 2 ...@@ -1793,13 +1777,13 @@ batch_size = 2
seqdim = 2048 seqdim = 2048
values = [] values = []
values.append((batch_size, seqdim, 768, 4 * 768)) values.append((batch_size, seqdim, 768, 4 * 768))
-values.append((batch_size, seqdim, 1024, 4*1024))
-values.append((batch_size, seqdim, 1536, 4*1536))
-values.append((batch_size, seqdim, 2048, 4*2048))
-values.append((batch_size, seqdim, 2560, 4*2560))
-values.append((batch_size, seqdim, 4096, 4*4096))
-values.append((batch_size, seqdim, 5140, 4*5140))
-values.append((batch_size, seqdim, 12288, 4*12288))
+#values.append((batch_size, seqdim, 1024, 4*1024))
+#values.append((batch_size, seqdim, 1536, 4*1536))
+#values.append((batch_size, seqdim, 2048, 4*2048))
+#values.append((batch_size, seqdim, 2560, 4*2560))
+#values.append((batch_size, seqdim, 4096, 4*4096))
+#values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden): def test_bench_matmul(batch, seq, model, hidden):
...@@ -2047,7 +2031,7 @@ def test_extract_outliers(): ...@@ -2047,7 +2031,7 @@ def test_extract_outliers():
assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel() assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2) torch.testing.assert_close(outliers1, outliers2)
CA, SA = F.transform(A, "col_ampere") CA, SA = F.transform(A, "col_ampere")
...@@ -2056,7 +2040,7 @@ def test_extract_outliers(): ...@@ -2056,7 +2040,7 @@ def test_extract_outliers():
assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel() assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2) torch.testing.assert_close(outliers1, outliers2)
...@@ -2186,7 +2170,7 @@ def test_few_bit_quant(): ...@@ -2186,7 +2170,7 @@ def test_few_bit_quant():
#assert err2.mean() <= err1 #assert err2.mean() <= err1
else: else:
torch.testing.assert_allclose(q1, q2) torch.testing.assert_close(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False #assert False
...@@ -2218,7 +2202,9 @@ def test_kbit_quantile_estimation(): ...@@ -2218,7 +2202,9 @@ def test_kbit_quantile_estimation():
def test_bench_dequantization(): def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half() a = torch.rand(1024, 1024, device='cuda').half()
-    qa, SA = F.quantize_blockwise(a)
+    code = F.create_fp8_map(True, 3, 0, 4).cuda()
+    qa, SA = F.quantize_blockwise(a, code=code)
+    print(qa.max())
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu) #print(max_theoretical_mu)
...@@ -2489,6 +2475,7 @@ def test_gemm_4bit(dtype): ...@@ -2489,6 +2475,7 @@ def test_gemm_4bit(dtype):
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
print(dim, (max_err.item(), max_relerr.item())) print(dim, (max_err.item(), max_relerr.item()))
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed(): def test_managed():
n = 32*10 n = 32*10
A = F.get_paged(n, n, dtype=torch.float32) A = F.get_paged(n, n, dtype=torch.float32)
...@@ -2523,4 +2510,4 @@ def test_managed(): ...@@ -2523,4 +2510,4 @@ def test_managed():
# assert (A==17).sum().item() == n*n # assert (A==17).sum().item() == n*n
# torch.testing.assert_allclose(A, torch.ones(A.shape)*289) # torch.testing.assert_close(A, torch.ones(A.shape)*289)
-import bitsandbytes as bnb
+import os
+from contextlib import nullcontext
+from itertools import product
+from tempfile import TemporaryDirectory
 import pytest
 import torch
-from bitsandbytes import functional as F
+import bitsandbytes as bnb
+from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
# contributed by Alex Borzunov, see: # contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
...@@ -26,6 +32,7 @@ def test_layout_exact_match(): ...@@ -26,6 +32,7 @@ def test_layout_exact_match():
assert restored_x.is_contiguous() assert restored_x.is_contiguous()
assert torch.all(torch.eq(restored_x, x)) assert torch.all(torch.eq(restored_x, x))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt(): def test_linear_no_igemmlt():
linear = torch.nn.Linear(1024, 3072) linear = torch.nn.Linear(1024, 3072)
...@@ -43,7 +50,7 @@ def test_linear_no_igemmlt(): ...@@ -43,7 +50,7 @@ def test_linear_no_igemmlt():
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
).to(linear.weight.dtype) ).to(linear.weight.dtype)
linear_custom.bias = linear.bias linear_custom.bias = linear.bias
-    linear = linear_custom.cuda()
+    linear_custom = linear_custom.cuda()
linear = linear.half().cuda() linear = linear.half().cuda()
x_ref = x.clone().cuda().requires_grad_(True) x_ref = x.clone().cuda().requires_grad_(True)
...@@ -59,3 +66,78 @@ def test_linear_no_igemmlt(): ...@@ -59,3 +66,78 @@ def test_linear_no_igemmlt():
assert not linear_custom.state.has_fp16_weights assert not linear_custom.state.has_fp16_weights
assert linear_custom.state.CB is not None assert linear_custom.state.CB is not None
assert linear_custom.state.CxB is None assert linear_custom.state.CxB is None
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
list(product([False, True], [False, True], [False, True], [False, True])))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half)
linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
if serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()
x_first = x.clone().cuda().requires_grad_(True)
fx_first = linear_custom(x_first).float()
grad_proj = torch.randn_like(fx_first)
(fx_first * grad_proj).mean().backward()
if not serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()
with TemporaryDirectory() as tmpdir:
state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
state_path = os.path.join(tmpdir, "state.pth")
torch.save(linear.state_dict(), state_path)
torch.save(state_dict_8bit, state_path_8bit)
if not has_fp16_weights:
assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
new_state_dict = torch.load(state_path_8bit)
new_linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
new_linear_custom.state.force_no_igemmlt = True
if deserialize_before_cuda:
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
new_linear_custom.load_state_dict(new_state_dict, strict=True)
new_linear_custom = new_linear_custom.cuda()
if not deserialize_before_cuda:
new_linear_custom.load_state_dict(new_state_dict, strict=True)
x_second = x.clone().cuda().requires_grad_(True)
fx_second = new_linear_custom(x_second).float()
(fx_second * grad_proj).mean().backward()
# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
if has_fp16_weights or not deserialize_before_cuda:
assert torch.allclose(fx_first, fx_second, atol=1e-5)
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
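The new serialization test above effectively documents the intended save/load flow for ``Linear8bitLt``. A condensed sketch of that flow, using the same calls as the test (sizes and file names here are arbitrary):

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Linear8bitLt

fp16_linear = torch.nn.Linear(32, 96)

# Wrap the fp16 weights in an int8 layer, move it to the GPU, run a forward pass.
int8_linear = Linear8bitLt(32, 96, bias=True, has_fp16_weights=False, threshold=6.0)
int8_linear.weight = bnb.nn.Int8Params(
    fp16_linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
)
int8_linear.bias = fp16_linear.bias
int8_linear = int8_linear.cuda()
int8_linear(torch.randn(3, 32, dtype=torch.half, device="cuda"))

# The int8 state dict is much smaller than the fp16 one (also asserted in the test).
torch.save(int8_linear.state_dict(), "state_8bit.pth")

# Restore into a fresh layer: move it to the GPU *before* loading int8 weights,
# otherwise load_state_dict raises a RuntimeError (see deserialize_before_cuda above).
restored = Linear8bitLt(32, 96, bias=True, has_fp16_weights=False, threshold=6.0)
restored = restored.cuda()
restored.load_state_dict(torch.load("state_8bit.pth"), strict=True)
```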
...@@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): ...@@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
sumval = (idx == 0).sum().item() sumval = (idx == 0).sum().item()
if sumval > count: if sumval > count:
print(f"Too many values not close: assert {sumval} < {count}") print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol) torch.testing.assert_close(a, b, rtol, atol)
class LinearFunction(torch.autograd.Function): class LinearFunction(torch.autograd.Function):
...@@ -353,6 +353,7 @@ def test_linear8bitlt_accumulated_gradient(): ...@@ -353,6 +353,7 @@ def test_linear8bitlt_accumulated_gradient():
assert l1[0].state.CxB is not None assert l1[0].state.CxB is not None
assert l1[1].state.CxB is not None assert l1[1].state.CxB is not None
print(i)
if i > 0 and i % acc_steps == 0: if i > 0 and i % acc_steps == 0:
opt1.step() opt1.step()
opt1.zero_grad(True) opt1.zero_grad(True)
...@@ -368,8 +369,8 @@ def test_linear8bitlt_accumulated_gradient(): ...@@ -368,8 +369,8 @@ def test_linear8bitlt_accumulated_gradient():
l1[0].weight.data.copy_(l2[0].weight.data) l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data) l1[1].weight.data.copy_(l2[1].weight.data)
else: else:
torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad) torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad) torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
@pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("threshold", [0.0, 2.0])
...@@ -478,7 +479,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): ...@@ -478,7 +479,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
scale = grad_ref.abs().mean() scale = grad_ref.abs().mean()
torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
assert (idx == 0).sum().item() <= b1.numel() * 0.005 assert (idx == 0).sum().item() <= b1.numel() * 0.005
...@@ -559,11 +560,11 @@ def test_kbit_backprop(module): ...@@ -559,11 +560,11 @@ def test_kbit_backprop(module):
relerrs2.append(relerr2.mean().item()) relerrs2.append(relerr2.mean().item())
if isinstance(module, bnb.nn.Linear8bitLt): if isinstance(module, bnb.nn.Linear8bitLt):
torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05) torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05) torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
else: else:
torch.testing.assert_allclose(grad1, grad2, atol=0.015, rtol=0.05) torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.02, rtol=0.05) torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
ref.zero_grad() ref.zero_grad()
kbit.zero_grad() kbit.zero_grad()
...@@ -574,4 +575,39 @@ def test_kbit_backprop(module): ...@@ -574,4 +575,39 @@ def test_kbit_backprop(module):
print('rel out', sum(relerrs1)/len(relerrs1)) print('rel out', sum(relerrs1)/len(relerrs1))
print('rel grad', sum(relerrs2)/len(relerrs2)) print('rel grad', sum(relerrs2)/len(relerrs2))
def test_fp8linear():
b = 10
h = 1024
inp = torch.randn(b, h).cuda()
fp32 = torch.nn.Linear(h, h*2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
fp32b = torch.nn.Linear(h*2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data)
fp8b.weight.data.copy_(fp32b.weight.data)
fp8b.bias.data.copy_(fp32b.bias.data)
a = fp32b(torch.nn.functional.gelu(fp32(inp)))
b = fp8b(torch.nn.functional.gelu(fp8(inp)))
err = (a-b).abs().mean()
a.mean().backward()
b.mean().backward()
graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()
assert err < 0.05
assert graderr < 0.00002
assert bgraderr < 0.00002
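The module-level counterpart used here is ``bnb.research.nn.LinearFP8Mixed``, which the test treats as a drop-in for ``torch.nn.Linear``. A minimal sketch along the same lines (shapes are arbitrary):

```python
import torch
import bitsandbytes as bnb

# Two-layer block whose matmuls run through the FP8 mixed-precision path.
net = torch.nn.Sequential(
    bnb.research.nn.LinearFP8Mixed(1024, 2048),
    torch.nn.GELU(),
    bnb.research.nn.LinearFP8Mixed(2048, 1024),
).cuda()

x = torch.randn(10, 1024, device="cuda")
out = net(x)
out.mean().backward()  # weight and bias gradients are produced, as the test checks
```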
...@@ -7,6 +7,8 @@ from itertools import product ...@@ -7,6 +7,8 @@ from itertools import product
from os.path import join from os.path import join
import pytest import pytest
from lion_pytorch import Lion
import torch import torch
import bitsandbytes as bnb import bitsandbytes as bnb
...@@ -16,6 +18,13 @@ import bitsandbytes.functional as F ...@@ -16,6 +18,13 @@ import bitsandbytes.functional as F
k = 20 k = 20
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol, atol)
error_count = (idx == 0).sum().item()
if error_count > max_error_count:
print(f"Too many values not close: assert {error_count} < {max_error_count}")
torch.testing.assert_close(a, b, rtol, atol)
def get_temp_dir(): def get_temp_dir():
path = f"/tmp/autoswap/{str(uuid.uuid4())}" path = f"/tmp/autoswap/{str(uuid.uuid4())}"
...@@ -33,6 +42,7 @@ str2optimizers = {} ...@@ -33,6 +42,7 @@ str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = ( str2optimizers["momentum_pytorch"] = (
None, None,
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
...@@ -42,6 +52,7 @@ str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) ...@@ -42,6 +52,7 @@ str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["momentum"] = ( str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
...@@ -51,6 +62,7 @@ str2optimizers["rmsprop"] = ( ...@@ -51,6 +62,7 @@ str2optimizers["rmsprop"] = (
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["momentum8bit"] = ( str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
...@@ -63,6 +75,7 @@ str2optimizers["rmsprop8bit"] = ( ...@@ -63,6 +75,7 @@ str2optimizers["rmsprop8bit"] = (
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = ( str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
...@@ -76,6 +89,7 @@ str2statenames = {} ...@@ -76,6 +89,7 @@ str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["rmsprop"] = [("square_avg", "state1")]
...@@ -85,14 +99,16 @@ str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1" ...@@ -85,14 +99,16 @@ str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -121,6 +137,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
    torch.testing.assert_close(
        torch_optimizer.state[p1][name1],
@@ -129,7 +146,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        rtol=rtol,
    )
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
if i % (k // 5) == 0 and i > 0:
    path = get_temp_dir()
@@ -139,14 +158,15 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
for name1, name2 in str2statenames[optim_name]:
    # since Lion can have pretty noisy updates where things lie at the boundary
    # allow up to 10 errors for Lion
    assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
                             atol=atol, rtol=rtol,
                             max_error_count=10)
if gtype != torch.float32:
    # the adam buffers should also be close because they are 32-bit
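assert_most_approx_close is defined earlier in test_optim.py and is not shown in these hunks; conceptually it tolerates up to max_error_count mismatching elements before falling back to a strict comparison. A minimal sketch of such a helper follows; the repository's implementation may differ in details.

def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
    # count elements that fall outside the tolerance instead of failing on the first one
    is_close = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (~is_close).sum().item()
    if error_count > max_error_count:
        # too many outliers: let torch.testing produce the detailed mismatch report
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)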
@@ -218,9 +238,11 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [
    "adam8bit",
    "lion8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lion8bit_blockwise",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
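For context on the new lion8bit entries: the optimizers under test are meant to be drop-in replacements for their torch counterparts. A hedged usage sketch follows; the model, learning rate, and the blockwise default are assumptions for illustration, not taken from this diff.

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()
# 8-bit Lion optimizer states; blockwise quantization assumed to be the default
optimizer = bnb.optim.Lion8bit(model.parameters(), lr=1e-4)

out = model(torch.randn(16, 1024, device="cuda"))
out.mean().backward()
optimizer.step()
optimizer.zero_grad()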
@@ -264,7 +286,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -292,7 +316,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states.append(s1.clone())
err = torch.abs(p1 - p2)
relerr = err / (torch.abs(p1)+1e-9)
if g.dtype == torch.bfloat16:
    assert err.mean() < 0.00015
    assert relerr.mean() < 0.0016
@@ -338,7 +362,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
assert num_not_close.sum().item() < 20
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
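The re-sync code that these two comments describe is elided by the hunk above. Purely as an illustration of the idea, and with hypothetical lines rather than the repository code, keeping the two parameter copies aligned could look like this:

# copy the reference parameters into the 8-bit optimizer's parameters so the next
# steps start from identical values and only the per-step error is measured
p2.copy_(p1.to(p2.dtype))
torch.testing.assert_close(p1.to(p2.dtype), p2)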
@@ -491,7 +517,7 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
print(optim_name, gtype, s / params)
# assert s < 3.9
dim1 = [2*1024]
gtype = [torch.float16]
#mode = ['torch', 'bnb']
mode = ['bnb']
...
import pytest
import torch
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt
@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
@pytest.mark.parametrize("vector_wise_quantization", [False, True])
def test_switchback(vector_wise_quantization):
    for dim in [83]:
        for batch in [13]:
            # three layers sharing the same weights: fp16 reference, SwitchBack (triton int8), and Linear8bitLt baseline
            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
            baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
            switchback.weight.data.copy_(standard.weight)
            switchback.bias.data.copy_(standard.bias)
            baseline.weight.data.copy_(standard.weight)
            baseline.bias.data.copy_(standard.bias)

            x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
            x2 = x1.clone().detach().requires_grad_(True)
            x3 = x1.clone().detach().requires_grad_(True)

            out_standard = standard(x1)
            (2**10 * out_standard.abs().mean()).backward()

            print(x2.dtype)
            out_sb = switchback(x2)
            (2**10 * out_sb.abs().mean()).backward()

            out_baseline = baseline(x3)
            (2**10 * out_baseline.abs().mean()).backward()

            # SwitchBack's output and gradient errors must stay within 2x of the Linear8bitLt baseline
            err_sb = (out_standard - out_sb).abs().mean()
            err_baseline = (out_standard - out_baseline).abs().mean()
            print('OUT', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
            err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
            print('GW2', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
            err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
            print('GW1', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline

            err_sb = (x1.grad - x2.grad).abs().mean()
            err_baseline = (x1.grad - x3.grad).abs().mean()
            print('GX1', err_sb, err_baseline)
            assert err_sb < 2 * err_baseline
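Outside of this test, SwitchBackLinear is used the same way: construct it with the usual in/out feature sizes and run it in fp16 on a GPU with triton available and compute capability 8.0 or higher. A hedged usage sketch, with shapes and flags chosen only for illustration:

import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

layer = SwitchBackLinear(256, 1024, vector_wise_quantization=True).cuda().half()
x = torch.randn(8, 256, device="cuda", dtype=torch.half, requires_grad=True)
y = layer(x)               # forward pass runs int8 matmuls via triton kernels
y.abs().mean().backward()  # gradients flow back to x and to layer.weight / layer.bias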