"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "0d99ae1fe84f8d191abe5ed1c2f4fdc5a9f9a773"
Commit 7140c014 authored by Tim Dettmers

Merge branch 'main' into fp8_merge

parents dd562c24 32f8c892
@@ -12,10 +12,12 @@ URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installer
 URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
 URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
 URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run
+URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run

 CUDA_VERSION=$1
 BASE_PATH=$2
+EXPORT_BASHRC=$3

 if [[ -n "$CUDA_VERSION" ]]; then
   if [[ "$CUDA_VERSION" -eq "92" ]]; then
@@ -60,11 +62,14 @@ if [[ -n "$CUDA_VERSION" ]]; then
   elif [[ "$CUDA_VERSION" -eq "120" ]]; then
     URL=$URL120
     FOLDER=cuda-12.0
+  elif [[ "$CUDA_VERSION" -eq "121" ]]; then
+    URL=$URL121
+    FOLDER=cuda-12.1
   else
-    echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
+    echo "argument error: No cuda version passed as input. Choose among versions 92 to 121"
   fi
 else
-  echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
+  echo "argument error: No cuda version passed as input. Choose among versions 92 to 112"
 fi

 FILE=$(basename $URL)
@@ -72,11 +77,13 @@ FILE=$(basename $URL)
 if [[ -n "$CUDA_VERSION" ]]; then
   echo $URL
   echo $FILE
-  wget $URL
+  #wget $URL
   bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
-  echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
-  echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
+  if [ "$EXPORT_BASHRC" -eq "1" ]; then
+    echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc
+    echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc
   source ~/.bashrc
+  fi
 else
   echo ""
 fi
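For context, the installer script above now takes a third argument that controls whether the PATH/LD_LIBRARY_PATH exports are appended to ~/.bashrc, and the wget call is commented out, so the .run installer has to be present before the script runs. A minimal usage sketch; the script name install_cuda.sh is an assumption, adjust it to wherever this script lives in the repo:

```bash
# Hypothetical invocation of the updated installer script (name assumed).
# Download the 12.1 installer first, since the script no longer calls wget itself:
wget https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run

# Install CUDA 12.1 under ~/local and append the exports to ~/.bashrc:
bash install_cuda.sh 121 ~/local 1

# Same install, but leave ~/.bashrc untouched (third argument not equal to 1):
bash install_cuda.sh 121 ~/local 0
```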
@@ -10,8 +10,8 @@ if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then
 fi

-module unload cuda
-module unload gcc
+module unload cuda && echo "no module function available. Probably not on a slurm cluster."
+module unload gcc && echo "no module function available. Probably not on a slurm cluster."

 rm -rf dist build
 make cleaneggs
@@ -128,6 +128,16 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then
   exit 64
 fi

+make clean
+export CUDA_HOME=$BASE_PATH/cuda-12.1
+make cuda12x CUDA_VERSION=121
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
+  # Control will enter here if $DIRECTORY doesn't exist.
+  echo "Compilation unsuccessul!" 1>&2
+  exit 64
+fi
+
 make clean
 export CUDA_HOME=$BASE_PATH/cuda-10.2
@@ -241,5 +251,15 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then
   exit 64
 fi

+make clean
+export CUDA_HOME=$BASE_PATH/cuda-12.1
+make cuda12x_nomatmul CUDA_VERSION=121
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then
+  # Control will enter here if $DIRECTORY doesn't exist.
+  echo "Compilation unsuccessul!" 1>&2
+  exit 64
+fi
+
 python -m build
 python -m twine upload dist/* --verbose
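The two new blocks above mirror the existing per-version guards: build the CUDA 12.1 variants and abort if the shared objects are missing. A hedged sketch of the same check, run manually after a build; the library paths are taken from the diff, everything else is an assumption:

```bash
# Verify that the newly added CUDA 12.1 binaries were actually produced.
for lib in libbitsandbytes_cuda121.so libbitsandbytes_cuda121_nocublaslt.so; do
    if [ ! -f "./bitsandbytes/$lib" ]; then
        echo "Compilation unsuccessful: missing $lib" 1>&2
        exit 64
    fi
done
echo "CUDA 12.1 binaries present."
```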
# No kernel image available

-This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. So solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``?
+This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``?
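A hedged sketch of the checks the paragraph above describes, collected into one snippet; the anaconda3 path is just the example used in the doc, substitute your own environment paths:

```bash
# Print the variables the docs say to debug:
echo "PATH=$PATH"
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
echo "CUDA_HOME=$CUDA_HOME"

# Look for CUDA runtimes hiding in an Anaconda/conda environment:
ls -l $HOME/anaconda3/lib/*cuda* 2>/dev/null
ls -l "${CONDA_PREFIX:-/nonexistent}"/lib/libcudart.so* 2>/dev/null

# Compare the versions found above against the driver-reported CUDA version:
nvidia-smi
```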
If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation.
@@ -18,7 +18,7 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.36.0-2",
+    version=f"0.38.0.post2",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
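The only change to setup.py is the version bump from 0.36.0-2 to 0.38.0.post2. Assuming that version was published to PyPI, the matching release can be pulled in directly:

```bash
# Install the release corresponding to this version bump (assumes it is on PyPI):
pip install bitsandbytes==0.38.0.post2
```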
@@ -5,95 +5,20 @@ import pytest
 import bitsandbytes as bnb
 from bitsandbytes.cuda_setup.main import (
-    CUDA_RUNTIME_LIB,
     determine_cuda_runtime_lib_path,
     evaluate_cuda_setup,
     extract_candidate_paths,
 )

-"""
-'LD_LIBRARY_PATH': ':/mnt/D/titus/local/cuda-11.1/lib64/'
-'CONDA_EXE': '/mnt/D/titus/miniconda/bin/conda'
-'LESSCLOSE': '/usr/bin/lesspipe %s %s'
-'OLDPWD': '/mnt/D/titus/src'
-'CONDA_PREFIX': '/mnt/D/titus/miniconda/envs/8-bit'
-'SSH_AUTH_SOCK': '/mnt/D/titus/.ssh/ssh-agent.tim-uw.sock'
-'CONDA_PREFIX_1': '/mnt/D/titus/miniconda'
-'PWD': '/mnt/D/titus/src/8-bit'
-'HOME': '/mnt/D/titus'
-'CONDA_PYTHON_EXE': '/mnt/D/titus/miniconda/bin/python'
-'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
-'TMUX': '/tmp/tmux-1007/default,59286,1'
-'XDG_DATA_DIRS': '/usr/local/share:/usr/share:/var/lib/snapd/desktop'
-'SSH_TTY': '/dev/pts/0'
-'MAIL': '/var/mail/titus'
-'SHELL': '/bin/bash'
-'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1007/bus'
-'XDG_RUNTIME_DIR': '/run/user/1007'
-'PATH': '/mnt/D/titus/miniconda/envs/8-bit/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/mnt/D/titus/local/cuda-11.1/bin'
-'LESSOPEN': '| /usr/bin/lesspipe %s'
-'_': '/mnt/D/titus/miniconda/envs/8-bit/bin/python'
-# any that include 'CONDA' that are not 'CONDA_PREFIX'
-# we search for
-'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
-"""
-
-class InputAndExpectedOutput(NamedTuple):
-    input: str
-    output: str
-
-HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
-    (
-        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
-        f"dir/with/{CUDA_RUNTIME_LIB}",
-    ),
-    (
-        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
-        f"dir/with/{CUDA_RUNTIME_LIB}",
-    ),
-    (
-        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
-        f"dir/with/{CUDA_RUNTIME_LIB}",
-    ),
-    (
-        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
-        f"dir/with/{CUDA_RUNTIME_LIB}",
-    ),
-    (
-        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
-        f"dir/with/{CUDA_RUNTIME_LIB}",
-    ),
-    (
-        f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
-        f"dir/with/{CUDA_RUNTIME_LIB}",
-    ),
-]
-
-@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
-def happy_path_path_string(tmpdir, request):
-    for path in extract_candidate_paths(request.param):
-        test_dir.mkdir()
-        if CUDA_RUNTIME_LIB in path:
-            (test_input / CUDA_RUNTIME_LIB).touch()
-
-UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
-    f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
-    f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
-]
-
-def test_full_system():
+def test_cuda_full_system():
     ## this only tests the cuda version and not compute capability
     # if CONDA_PREFIX exists, it has priority before all other env variables
     # but it does not contain the library directly, so we need to look at the a sub-folder
     version = ""
     if "CONDA_PREFIX" in os.environ:
-        ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
+        ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0')
         major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split("."))
         version = float(f"{major}.{minor}")
import os
from contextlib import nullcontext
from itertools import product
from tempfile import TemporaryDirectory

import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt

# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py


@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
    x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
    for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
        transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
        tile_indices = get_inverse_transform_indices(transform, tile_size)
        cxb = transform(x)

        torch.cuda.synchronize()
        restored_x = undo_layout(cxb, tile_indices)
        torch.cuda.synchronize()
        assert restored_x.is_contiguous()
        assert torch.all(torch.eq(restored_x, x))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()
    linear = linear.half().cuda()

    x_ref = x.clone().cuda().requires_grad_(True)
    x_ours = x.clone().cuda().requires_grad_(True)
    fx_ref = linear(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()
    assert torch.allclose(fx_ref, fx_ours, atol=0.02)
    assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is not None
    assert linear_custom.state.CxB is None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
                         list(product([False, True], [False, True], [False, True], [False, True])))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
    linear = torch.nn.Linear(32, 96)
    x = torch.randn(3, 32, dtype=torch.half)

    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=has_fp16_weights,
        threshold=6.0,
    )
    if force_no_igemmlt:
        linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
    )
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()

    if serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    x_first = x.clone().cuda().requires_grad_(True)
    fx_first = linear_custom(x_first).float()
    grad_proj = torch.randn_like(fx_first)
    (fx_first * grad_proj).mean().backward()

    if not serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    with TemporaryDirectory() as tmpdir:
        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")

        torch.save(linear.state_dict(), state_path)
        torch.save(state_dict_8bit, state_path_8bit)

        if not has_fp16_weights:
            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)

        new_state_dict = torch.load(state_path_8bit)

    new_linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=has_fp16_weights,
        threshold=6.0,
    )
    if force_no_igemmlt:
        new_linear_custom.state.force_no_igemmlt = True

    if deserialize_before_cuda:
        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
            new_linear_custom.load_state_dict(new_state_dict, strict=True)

    new_linear_custom = new_linear_custom.cuda()

    if not deserialize_before_cuda:
        new_linear_custom.load_state_dict(new_state_dict, strict=True)

    x_second = x.clone().cuda().requires_grad_(True)
    fx_second = new_linear_custom(x_second).float()
    (fx_second * grad_proj).mean().backward()

    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
    if has_fp16_weights or not deserialize_before_cuda:
        assert torch.allclose(fx_first, fx_second, atol=1e-5)
        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
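To run only the new Linear8bitLt tests locally, something like the following should work; the tests/test_linear8bitlt.py path is an assumption about where the new file lands, and both tests skip themselves unless a CUDA GPU (Turing or newer for the layout test) is available:

```bash
# Install the current checkout, then run just the new test module (path assumed):
pip install -e .
pytest tests/test_linear8bitlt.py -v -x

# Or narrow it down to the serialization matrix only:
pytest tests/test_linear8bitlt.py -v -k "serialization"
```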
@@ -382,7 +382,7 @@ names = [f"threshold_{vals}" for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
-@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+@pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
         bnb.nn.Linear8bitLt(
@@ -7,6 +7,8 @@ from itertools import product
 from os.path import join

 import pytest
+from lion_pytorch import Lion

 import torch

 import bitsandbytes as bnb
@@ -16,6 +18,13 @@ import bitsandbytes.functional as F
 k = 20

+def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
+    idx = torch.isclose(a, b, rtol, atol)
+    error_count = (idx == 0).sum().item()
+    if error_count > max_error_count:
+        print(f"Too many values not close: assert {error_count} < {max_error_count}")
+        torch.testing.assert_allclose(a, b, rtol, atol)
+
 def get_temp_dir():
     path = f"/tmp/autoswap/{str(uuid.uuid4())}"
@@ -31,6 +40,7 @@ str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
 # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@@ -38,6 +48,7 @@ str2optimizers["momentum_pytorch"] = (
 )
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers["lion"] = (Lion, bnb.optim.Lion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@@ -54,6 +65,10 @@ str2optimizers["adam8bit"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
 )
+str2optimizers["lion8bit"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
+)
 str2optimizers["momentum8bit"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@@ -71,6 +86,10 @@ str2optimizers["adam8bit_blockwise"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
 )
+str2optimizers["lion8bit_blockwise"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
+)
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -82,6 +101,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
 str2statenames = {}
 str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lars"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
@@ -90,6 +110,9 @@ str2statenames["adam8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
 ]
+str2statenames["lion8bit"] = [
+    ("exp_avg", "state1", "qmap1", "max1")
+]
 str2statenames["lamb8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
@@ -98,6 +121,9 @@ str2statenames["adam8bit_blockwise"] = [
     ("exp_avg", "state1", "qmap1", "absmax1"),
     ("exp_avg_sq", "state2", "qmap2", "absmax2"),
 ]
+str2statenames["lion8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1")
+]
 str2statenames["momentum8bit"] = [
     ("momentum_buffer", "state1", "qmap1", "max1")
 ]
@@ -113,7 +139,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
+optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
@@ -144,6 +170,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
         bnb_optimizer.step()
         torch_optimizer.step()

         for name1, name2 in str2statenames[optim_name]:
             torch.testing.assert_allclose(
                 torch_optimizer.state[p1][name1],
@@ -152,7 +179,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
                 rtol=rtol,
             )

-        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
+        # since Lion can have pretty noisy updates where things lie at the boundary
+        # allow up to 10 errors for Lion
+        assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)

         if i % (k // 5) == 0 and i > 0:
             path = get_temp_dir()
@@ -162,14 +191,15 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
             bnb_optimizer = str2optimizers[optim_name][1]([p2])
             bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
             rm_path(path)
-            torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
+            # since Lion can have pretty noisy updates where things lie at the boundary
+            # allow up to 10 errors for Lion
+            assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
             for name1, name2 in str2statenames[optim_name]:
-                torch.testing.assert_allclose(
-                    torch_optimizer.state[p1][name1],
-                    bnb_optimizer.state[p2][name2],
-                    atol=atol,
-                    rtol=rtol,
-                )
+                # since Lion can have pretty noisy updates where things lie at the boundary
+                # allow up to 10 errors for Lion
+                assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
+                                         atol=atol, rtol=rtol,
+                                         max_error_count=10)

         if gtype == torch.float16:
             # the adam buffers should also be close because they are 32-bit
@@ -241,9 +271,11 @@ dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
 optimizer_names = [
     "adam8bit",
+    "lion8bit",
     "momentum8bit",
     "rmsprop8bit",
     "adam8bit_blockwise",
+    "lion8bit_blockwise",
     "lars8bit",
     "momentum8bit_blockwise",
     "rmsprop8bit_blockwise",
@@ -285,7 +317,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         bnb_optimizer.step()
         torch_optimizer.step()

-        torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+        # since Lion can have pretty noisy updates where things lie at the boundary
+        # allow up to 5 errors for Lion
+        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

         dequant_states = []
         for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -313,7 +347,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             dequant_states.append(s1.clone())

         err = torch.abs(p1 - p2)
-        relerr = err / torch.abs(p1)
+        relerr = err / (torch.abs(p1)+1e-9)
         assert err.mean() < 0.0001
         assert relerr.mean() < 0.001
@@ -367,9 +401,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                 == 0
             )
             assert num_not_close.sum().item() < 20
-            torch.testing.assert_allclose(
-                p1, p2.float(), atol=patol, rtol=prtol
-            )
+            # since Lion can have pretty noisy updates where things lie at the boundary
+            # allow up to 5 errors for Lion
+            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

             # the parameters diverge quickly. Here we keep them close
             # together so we can test against the Adam error
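The test changes above compare bnb.optim.Lion and its 8-bit variants against the reference implementation imported as lion_pytorch, so that package has to be installed before the suite runs. A hedged sketch for exercising only the Lion cases; the PyPI package name lion-pytorch and the tests/test_optim.py path are assumptions:

```bash
# Install the reference Lion optimizer used by the new tests (package name assumed):
pip install lion-pytorch

# Run only the Lion-related parametrizations; -k matches substrings in the test ids:
pytest tests/test_optim.py -v -k "lion"
```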