"vscode:/vscode.git/clone" did not exist on "9a59ebcec35c176af8ae3ecaf36f15e7ff486ec6"
Unverified commit 2336a45c authored by Aarni Koskela, committed by GitHub

Test improvements (#1001)

* test_nvidia_transform: fix variable reference

`out_order` is the global parametrization list, not the test fixture argument

* Make `parametrize` use more idiomatic

* Use a more deterministic helper for `dim*` determination

* Convert NO_CUBLASLT errors into skips too

* Mark slow and benchmark tests as such (allows `-k "not benchmark"`)
parent 1a0dc5c3
......@@ -8,3 +8,6 @@ addopts = -rP
log_cli = True
log_cli_level = INFO
log_file = logs/pytest.log
markers =
benchmark: mark test as benchmark
slow: mark test as slow
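With these markers registered, individual tests can be tagged and then deselected from the command line. A minimal, hypothetical sketch (the test below is illustrative, not part of the suite):

import pytest

@pytest.mark.benchmark
def test_example_throughput():
    # an expensive timing loop would live here
    assert True

# Deselect via a keyword or marker expression, e.g.:
#   pytest -k "not benchmark"   (as noted in the commit message)
#   pytest -m "not benchmark"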
......@@ -5,6 +5,10 @@ import torch
def pytest_runtest_call(item):
try:
item.runtest()
except NotImplementedError as nie:
if "NO_CUBLASLT" in str(nie):
pytest.skip("CUBLASLT not available")
raise
except AssertionError as ae:
if str(ae) == "Torch not compiled with CUDA enabled":
pytest.skip("Torch not compiled with CUDA enabled")
......
from itertools import product
import random
from typing import Any
import torch
test_dims_rng = random.Random(42)
def get_test_dims(min: int, max: int, *, n: int) -> list[int]:
return [test_dims_rng.randint(min, max) for _ in range(n)]
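Seeding a dedicated module-level RNG makes the drawn dimensions identical on every collection run, so parametrized test IDs stay stable. A small sketch of the underlying guarantee (illustrative values):

# Two RNGs with the same seed produce the same draws, which is why each
# fresh test session sees the same get_test_dims() output.
rng_a = random.Random(42)
rng_b = random.Random(42)
assert [rng_a.randint(16, 64) for _ in range(4)] == [rng_b.randint(16, 64) for _ in range(4)]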
def format_with_label(label: str, value: Any) -> str:
if isinstance(value, bool):
formatted = "T" if value else "F"
elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
formatted = "".join("T" if b else "F" for b in value)
else:
formatted = str(value)
return f"{label}={formatted}"
def id_formatter(label: str):
"""
Return a function that formats the value given to it with the given label.
"""
return lambda value: format_with_label(label, value)
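For example (values picked purely for illustration), the resulting IDs are short and greppable:

fmt = id_formatter("has_bias")
assert fmt(True) == "has_bias=T"
assert fmt((True, False)) == "has_bias=TF"
assert fmt(0.0) == "has_bias=0.0"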
DTYPE_NAMES = {
torch.bfloat16: "bf16",
torch.bool: "bool",
torch.float16: "fp16",
torch.float32: "fp32",
torch.float64: "fp64",
torch.int32: "int32",
torch.int64: "int64",
torch.int8: "int8",
}
def describe_dtype(dtype: torch.dtype) -> str:
return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]
TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(
product(TRUE_FALSE, repeat=3)
) # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool)
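Taken together, these helpers keep generated test IDs compact; for instance (illustrative checks, assuming the definitions above are in scope):

assert describe_dtype(torch.float16) == "fp16"
assert describe_dtype(torch.bfloat16) == "bf16"
assert len(BOOLEAN_TRIPLES) == 8 and len(BOOLEAN_TUPLES) == 4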
from itertools import product
from typing import Tuple
import pytest
import torch
import bitsandbytes as bnb
n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(
product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
product(
dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
)
)
from tests.helpers import (
BOOLEAN_TRIPLES,
BOOLEAN_TUPLES,
TRUE_FALSE,
describe_dtype,
get_test_dims,
id_formatter,
)
names = [
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
*vals
)
for vals in str_values
]
TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
values,
ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
for i in range(25):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
......@@ -228,71 +213,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
assert (idx == 0).sum().item() < n * 0.02
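Stacking one @pytest.mark.parametrize per argument, as in the rewritten decorators above, yields the same cross product as the old itertools.product lists while letting each argument contribute its own ID segment. A minimal, self-contained sketch with toy values:

import pytest

@pytest.mark.parametrize("a", [1, 2], ids=lambda v: f"a={v}")
@pytest.mark.parametrize("b", [True, False], ids=lambda v: "b=T" if v else "b=F")
def test_toy(a, b):
    # Collected as four items with IDs such as test_toy[b=T-a=1].
    assert isinstance(a, int) and isinstance(b, bool)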
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
str_funcs = ["matmullt", 'switchback_bnb']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
strval = ''
for v in c:
if v == True: strval += 'T'
else: strval += 'F'
req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.bfloat16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(
product(
dim1,
dim2,
dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
decomp,
has_fp16_weights,
has_bias
)
)
str_values = list(
product(
dim1,
dim2,
dim3,
dim4,
str_funcs,
dtype,
req_grad_str,
str_transpose,
decomp,
has_fp16_weights,
has_bias
)
)
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
values,
ids=names,
)
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(
dim1,
dim2,
......@@ -313,7 +244,7 @@ def test_matmullt(
req_grad = list(req_grad)
req_grad[2] = False
for i in range(k):
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
......@@ -429,45 +360,25 @@ def test_matmullt(
torch.testing.assert_close(gradBias1, gradBias2)
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
funcs = [(torch.matmul, bnb.matmul_4bit)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
strval = ''
for v in c:
if v == True: strval += 'T'
else: strval += 'F'
req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
compress_statistics = [False, True]
has_fp16_weights = [True, False]
has_bias = [True, False]
quant_type = ['fp4', 'nf4']
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"])
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type"))
def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
if has_bias == False:
req_grad = list(req_grad)
req_grad[2] = False
for i in range(k):
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
......@@ -530,32 +441,21 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
torch.testing.assert_close(gradBias1, gradBias2)
funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
strval = ''
for v in c:
if v == True: strval += 'T'
else: strval += 'F'
req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global'])
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
req_grad = list(req_grad)
req_grad[2] = False
for i in range(k):
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
......
......@@ -9,6 +9,8 @@ from transformers import (
BitsAndBytesConfig,
)
from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter
def get_4bit_config():
return BitsAndBytesConfig(
......@@ -59,23 +61,19 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
dtypes = ['nf4', 'fp4']
load_in_4bit = [True, False]
values = list(product(models, dtypes))
strfunc = lambda lst: [str(x) for x in lst]
ids = ['_'.join(strfunc(x)) for x in values]
@pytest.fixture(scope='session', params=values, ids=ids)
@pytest.fixture(scope='session', params=product(models, dtypes))
def model_and_tokenizer(request):
model, tokenizer = get_model_and_tokenizer(request.param)
yield request.param, model, tokenizer
del model
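Because the fixture above is session-scoped and parametrized, each (model, dtype) combination is constructed once and shared by every test that requests it, so heavyweight model loading is not repeated per test. A standalone sketch of the pattern (placeholder resource instead of a real model):

from itertools import product
import pytest

@pytest.fixture(scope="session", params=list(product(["model-a", "model-b"], ["nf4", "fp4"])))
def fake_model(request):
    name, quant = request.param
    yield request.param, f"{name}/{quant}"  # stands in for an expensive model/tokenizer load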
@pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
#@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ):
print('')
dtype = torch.float16
@pytest.mark.parametrize("DQ", TRUE_FALSE, ids=id_formatter("dq"))
@pytest.mark.parametrize("inference_kernel", TRUE_FALSE, ids=id_formatter("inference_kernel"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.slow
def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
fixture_config, model, tokenizer = model_and_tokenizer
generation_config = transformers.GenerationConfig(
......
from itertools import product
import os
from tempfile import TemporaryDirectory
......@@ -6,6 +5,7 @@ import pytest
import torch
import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE
storage = {
'uint8': torch.uint8,
......@@ -14,10 +14,10 @@ storage = {
'float32': torch.float32
}
@pytest.mark.parametrize(
"quant_type, compress_statistics, bias, quant_storage",
list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])),
)
@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
original_dtype = torch.float16
compute_dtype = None
......
from contextlib import nullcontext
from itertools import product
import os
from tempfile import TemporaryDirectory
......@@ -10,6 +9,7 @@ import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import TRUE_FALSE, id_formatter
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
......@@ -66,8 +66,10 @@ def test_linear_no_igemmlt():
assert linear_custom.state.CxB is None
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
list(product([False, True], [False, True], [False, True], [False, True])))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half)
......
......@@ -6,6 +6,7 @@ import torch
from torch import nn
import bitsandbytes as bnb
from tests.helpers import id_formatter
class MockArgs:
......@@ -311,12 +312,7 @@ class Linear8bit(nn.Module):
return LinearFunction.apply(x, self.weight, self.bias, self.args)
threshold = [0.0, 3.0]
values = threshold
names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold):
l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
assert l1.weight.device.type == "cuda"
......@@ -510,18 +506,21 @@ def test_linear_kbit_fp32_bias(module):
o1 = l1(b1)
assert l1.bias is None
modules = []
modules.append(bnb.nn.Linear8bitLt)
modules.append(bnb.nn.Linear4bit)
modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
@pytest.mark.parametrize("module", modules, ids=names)
module_dict = {
"Int8Lt": bnb.nn.Linear8bitLt,
"4bit": bnb.nn.Linear4bit,
"FP4": bnb.nn.LinearFP4,
"NF4": bnb.nn.LinearNF4,
"FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
"NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
"NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
"NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
"NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
}
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
b = 17
dim1 = 37
......
from itertools import product
import os
from os.path import join
import shutil
......@@ -11,6 +10,7 @@ import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from tests.helpers import describe_dtype, id_formatter
# import apex
......@@ -101,15 +101,16 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']:
pytest.skip()
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
......@@ -134,7 +135,6 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_close(
torch_optimizer.state[p1][name1],
......@@ -177,14 +177,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_global_config(dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1:
return
......@@ -230,10 +225,7 @@ def test_global_config(dim1, dim2, gtype):
assert adam2.state[p3]["state2"].dtype == torch.uint8
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [
optimizer_names_8bit = [
"adam8bit",
"lion8bit",
"momentum8bit",
......@@ -243,13 +235,12 @@ optimizer_names = [
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
if dim1 == 1 and dim2 == 1:
......@@ -375,18 +366,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
# print(sum(relerrors)/len(relerrors))
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
"dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits"))
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1:
return
......@@ -474,22 +457,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
adam2.load_state_dict(torch.load(join(path, "opt.pt")))
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
optimizer_names_benchmark = [
"adam8bit_blockwise",
"paged_adam8bit_blockwise",
"paged_adamw8bit_blockwise",
"paged_lion8bit_blockwise",
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
@pytest.mark.benchmark
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1:
return
......@@ -514,15 +494,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
print(optim_name, gtype, s / params)
# assert s < 3.9
dim1 = [2*1024]
gtype = [torch.float16]
#mode = ['torch', 'bnb']
mode = ['bnb']
optimizer_names = ['paged_adamw']
#optimizer_names = ['paged_adamw8bit_blockwise']
values = list(product(dim1,gtype, optimizer_names, mode))
names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name"))
@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode"))
@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
layers1 = layers1.to(gtype)
......
......@@ -4,11 +4,12 @@ import torch
from bitsandbytes.nn import Linear8bitLt
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.triton.triton_utils import is_triton_available
from tests.helpers import TRUE_FALSE
@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
@pytest.mark.parametrize("vector_wise_quantization", [False, True])
@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
def test_switchback(vector_wise_quantization):
for dim in [83]:
for batch in [13]:
......