Unverified commit 2336a45c authored by Aarni Koskela, committed by GitHub

Test improvements (#1001)

* test_nvidia_transform: fix variable reference

`out_order` is the global parametrization list, not the test fixture argument

* Make `parametrize` use more idiomatic

* Use a more deterministic helper for `dim*` determination

* Convert NO_CUBLASLT errors into skips too

* Mark slow and benchmark tests as such (allows `-k "not benchmark"`)
parent 1a0dc5c3
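
The conversions in this commit all follow the same pattern: module-level `values`/`names` lists built with `itertools.product` are replaced by stacked `@pytest.mark.parametrize` decorators that use the new `tests/helpers.py` utilities for dimensions and test IDs. A minimal sketch of the idiom, with a placeholder test body that is not part of this commit:

import pytest
import torch

from tests.helpers import TRUE_FALSE, describe_dtype, get_test_dims, id_formatter


@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_example(dim1, dtype, transpose):
    # Placeholder assertion only; the real tests in this commit exercise bitsandbytes kernels.
    assert dim1 >= 16

Because `get_test_dims` draws from a seeded `random.Random(42)`, the generated dimensions, and therefore the test IDs, are stable across runs. The `benchmark` and `slow` markers registered in the pytest configuration can also be deselected with a marker expression such as `-m "not benchmark"`, in addition to the `-k` name filter mentioned above.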
@@ -7,4 +7,7 @@ addopts = -rP
log_cli = True
log_cli_level = INFO
log_file = logs/pytest.log
markers =
    benchmark: mark test as benchmark
    slow: mark test as slow
@@ -5,6 +5,10 @@ import torch

def pytest_runtest_call(item):
    try:
        item.runtest()
    except NotImplementedError as nie:
        if "NO_CUBLASLT" in str(nie):
            pytest.skip("CUBLASLT not available")
        raise
    except AssertionError as ae:
        if str(ae) == "Torch not compiled with CUDA enabled":
            pytest.skip("Torch not compiled with CUDA enabled")
...
from itertools import product
import random
from typing import Any

import torch

test_dims_rng = random.Random(42)


def get_test_dims(min: int, max: int, *, n: int) -> list[int]:
    return [test_dims_rng.randint(min, max) for _ in range(n)]


def format_with_label(label: str, value: Any) -> str:
    if isinstance(value, bool):
        formatted = "T" if value else "F"
    elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
        formatted = "".join("T" if b else "F" for b in value)
    else:
        formatted = str(value)
    return f"{label}={formatted}"


def id_formatter(label: str):
    """
    Return a function that formats the value given to it with the given label.
    """
    return lambda value: format_with_label(label, value)


DTYPE_NAMES = {
    torch.bfloat16: "bf16",
    torch.bool: "bool",
    torch.float16: "fp16",
    torch.float32: "fp32",
    torch.float64: "fp64",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.int8: "int8",
}


def describe_dtype(dtype: torch.dtype) -> str:
    return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]


TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(
    product(TRUE_FALSE, repeat=3)
)  # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
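
For reference, a small illustrative snippet (not part of the commit) showing the IDs these helpers generate:

import torch

from tests.helpers import describe_dtype, format_with_label, id_formatter

print(format_with_label("has_bias", True))           # -> has_bias=T
print(format_with_label("req_grad", (True, False)))  # -> req_grad=TF
print(id_formatter("dim1")(42))                      # -> dim1=42
print(describe_dtype(torch.bfloat16))                # -> bf16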
from itertools import product
from typing import Tuple

import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import (
    BOOLEAN_TRIPLES,
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
)

n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
        *vals
    )
    for vals in str_values
]
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

TRANSPOSE_VALS = [(False, True), (False, False)]

@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
    for i in range(25):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
@@ -228,71 +213,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                assert (idx == 0).sum().item() < n * 0.02

n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
str_funcs = ["matmullt", 'switchback_bnb']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        if v == True: strval += 'T'
        else: strval += 'F'
    req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.bfloat16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
    values,
    ids=names,
)

@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(
    dim1,
    dim2,
@@ -313,7 +244,7 @@ def test_matmullt(
        req_grad = list(req_grad)
        req_grad[2] = False
    for i in range(k):
    for i in range(3):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
@@ -429,45 +360,25 @@ def test_matmullt(
            torch.testing.assert_close(gradBias1, gradBias2)

n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

funcs = [(torch.matmul, bnb.matmul_4bit)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        if v == True: strval += 'T'
        else: strval += 'F'
    req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
compress_statistics = [False, True]
has_fp16_weights = [True, False]
has_bias = [True, False]
quant_type = ['fp4', 'nf4']
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):

@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"])
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type"))
def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    if has_bias == False:
        req_grad = list(req_grad)
        req_grad[2] = False
    for i in range(k):
    for i in range(3):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
@@ -530,32 +441,21 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
            torch.testing.assert_close(gradBias1, gradBias2)

funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        if v == True: strval += 'T'
        else: strval += 'F'
    req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)

@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global'])
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    req_grad = list(req_grad)
    req_grad[2] = False
    for i in range(k):
    for i in range(3):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
...
@@ -11,6 +11,13 @@ import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
from tests.helpers import (
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
)

torch.set_printoptions(
    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
@@ -155,10 +162,10 @@ def test_dynamic_quantization():

@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False'])

@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
    #print('')
    diffs = []
@@ -281,34 +288,22 @@ def mean(xx):
    return sum(xx) / float(len(xx))

# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
    for vals in values_names
]
@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)

methods = {
    "linear": (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    ),
    "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
}

@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
...@@ -352,21 +347,10 @@ def test_stable_embedding(): ...@@ -352,21 +347,10 @@ def test_stable_embedding():
layer.reset_parameters() layer.reset_parameters()
n = 2 @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
hidden_dim = torch.randint(32, 256, size=(n,)).tolist() @pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
batch_dim = torch.randint(16, 256, size=(n,)).tolist() @pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
seq_dim = torch.randint(16, 256, size=(n,)).tolist() @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
"hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
for vals in values
]
@pytest.mark.parametrize(
"hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16) batch_dim = batch_dim - (batch_dim % 16)
...@@ -418,17 +402,9 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): ...@@ -418,17 +402,9 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
torch.testing.assert_close(out.float(), out2) torch.testing.assert_close(out.float(), out2)
n = 3 @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
seq_dim = torch.randint(32, 512, size=(n,)).tolist() @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
"seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
seq_dim = seq_dim - (seq_dim % 32) seq_dim = seq_dim - (seq_dim % 32)
hidden_dim = hidden_dim - (hidden_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32)
...@@ -449,21 +425,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): ...@@ -449,21 +425,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
torch.testing.assert_close(out.float(), out2) torch.testing.assert_close(out.float(), out2)
n = 2 @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
seq_dim = torch.randint(32, 512, size=(n,)).tolist() @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
batch_dim = torch.randint(2, 16, size=(n,)).tolist() @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
"seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
for vals in values
]
@pytest.mark.parametrize(
"seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x): def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True) maxA = torch.amax(x, dim=2, keepdim=True)
...@@ -533,20 +498,11 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): ...@@ -533,20 +498,11 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
assert mean(relerrs) < 0.3 assert mean(relerrs) < 0.3
n = 2 @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
dim1 = torch.randint(1, 64, size=(n,)).tolist() @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
dim2 = torch.randint(32, 128, size=(n,)).tolist() @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
dim3 = torch.randint(32, 256, size=(n,)).tolist() @pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
dim4 = torch.randint(32, 256, size=(n,)).tolist() @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose): def test_ibmm(dim1, dim2, dim3, dim4, transpose):
dim2 = dim2 - (dim2 % 16) dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16) dim3 = dim3 - (dim3 % 16)
...@@ -574,15 +530,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): ...@@ -574,15 +530,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
torch.testing.assert_close(out.float(), out2.float()) torch.testing.assert_close(out.float(), out2.float())
n = 1 @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
dim1 = torch.randint(1, 64, size=(n,)).tolist() @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
dim2 = torch.randint(32, 128, size=(n,)).tolist() @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3): def test_vector_quant(dim1, dim2, dim3):
dim2 = dim2 - (dim2 % 16) dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16) dim3 = dim3 - (dim3 % 16)
@@ -594,24 +544,14 @@ def test_vector_quant(dim1, dim2, dim3):
    assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))

n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)

@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
...@@ -677,28 +617,12 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans ...@@ -677,28 +617,12 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
torch.testing.assert_close(A, out2) torch.testing.assert_close(A, out2)
n = 1 @pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1"))
dim1 = torch.randint(1, 256, size=(n,)).tolist() @pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2"))
dim2 = torch.randint(32, 512, size=(n,)).tolist() @pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3"))
dim3 = torch.randint(32, 1024, size=(n,)).tolist() @pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4"))
dim4 = torch.randint(32, 1024, size=(n,)).tolist() @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]
dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
...@@ -732,21 +656,11 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): ...@@ -732,21 +656,11 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch.testing.assert_close(C1, C3.float()) torch.testing.assert_close(C1, C3.float())
dim1 = [32] @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
dim2 = [32] @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
dim3 = [32] @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
dim4 = [32] @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
formatB = F.get_special_format_str() formatB = F.get_special_format_str()
for i in range(k): for i in range(k):
@@ -786,24 +700,15 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    # C3, S = F.transform(C2, 'row', state=SC)
    # torch.testing.assert_close(C1, C3.float())

batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]
# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)

@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [
        pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"),
        pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"),
        pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"),
    ],
)
@pytest.mark.benchmark
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
@@ -953,24 +858,11 @@ def test_bench_8bit_training(batch, seq, model, hidden):
    # print(t8)
n = 2 @pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1"))
dim1 = torch.randint(64, 256, size=(n,)).tolist() @pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4"))
dim4 = torch.randint(64, 1024, size=(n,)).tolist() @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB"))
#dim1 = [2*1024] @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
#dim4 = [2*1024]
#dim1 = [4]
#dim4 = [4]
dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
inner = torch.randint(1, 128, size=(1,)).item() inner = torch.randint(1, 128, size=(1,)).item()
bias = None bias = None
@@ -994,33 +886,23 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
    if has_bias: C4 += bias

    # TODO: is something wrong here? If so, the problem goes deeper
    # n = C1.numel()
    # p = 0.06
    std = C1.std(0).view(1, -1)
    C1 /= std
    C4 /= std
    # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
    # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

    C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
    # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
    n = C5.numel()
    assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_colrow_absmax(dim1, dim2, dims): def test_colrow_absmax(dim1, dim2, dims):
for i in range(k): for i in range(k):
threshold = 3.0 threshold = 3.0
...@@ -1066,17 +948,8 @@ def test_colrow_absmax(dim1, dim2, dims): ...@@ -1066,17 +948,8 @@ def test_colrow_absmax(dim1, dim2, dims):
assert nnz_block_ptr2 is None assert nnz_block_ptr2 is None
n = 2 @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
# dim1 = [8*1024] @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2): def test_double_quant(dim1, dim2):
for i in range(k): for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half() A = torch.randn(dim1, dim2, device="cuda").half()
@@ -1114,16 +987,18 @@ def test_double_quant(dim1, dim2):
        torch.testing.assert_close(Scol.flatten().float(), statsAt)

n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)

@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner)
        in zip(
            get_test_dims(1, 4 * 1024, n=4),
            get_test_dims(1, 4 * 1024, n=4),
            get_test_dims(1, 4 * 1024, n=4),
        )
    )
)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
@@ -1158,16 +1033,18 @@ def test_integrated_igemmlt(dim1, dim4, inner):
    assert err2 <= err1 * 1.025

n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)

@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner)
        in zip(
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
        )
    )
)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
@@ -1234,17 +1111,17 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
    print(sum(err3) / len(err3))

dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)

@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    [
        pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"),
        pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"),
    ],
)
@pytest.mark.skip("Row scale has some bugs for ampere")
@pytest.mark.benchmark
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
...@@ -1289,34 +1166,14 @@ def test_row_scale_bench(dim1, dim4, inner): ...@@ -1289,34 +1166,14 @@ def test_row_scale_bench(dim1, dim4, inner):
print("vector-wise", time.time() - t0) print("vector-wise", time.time() - t0)
n = 2 @pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
dim1 = torch.randint(2, 1024, size=(n,)).tolist() @pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
dim2 = torch.randint(2, 1024, size=(n,)).tolist() @pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
# dim1 = [8*1024] @pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
# dim2 = [4*1024] @pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
dim3 = [0] @pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
dtype = [torch.int8] @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
"dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
*vals
)
for vals in values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
values,
ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
...@@ -1344,23 +1201,6 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): ...@@ -1344,23 +1201,6 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
torch.testing.assert_close(out1, out2) torch.testing.assert_close(out1, out2)
n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]
dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
"dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
for vals in values
]
def test_overflow(): def test_overflow():
formatB = F.get_special_format_str() formatB = F.get_special_format_str()
print(formatB) print(formatB)
...@@ -1375,17 +1215,8 @@ def test_overflow(): ...@@ -1375,17 +1215,8 @@ def test_overflow():
c2 = torch.matmul(a.float(), b.float().t()) c2 = torch.matmul(a.float(), b.float().t())
n = 2 @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2): def test_coo_double_quant(dim1, dim2):
threshold = 3.00 threshold = 3.00
for i in range(k): for i in range(k):
...@@ -1412,17 +1243,9 @@ def test_coo_double_quant(dim1, dim2): ...@@ -1412,17 +1243,9 @@ def test_coo_double_quant(dim1, dim2):
) )
n = 2 @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist() @pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() @pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B): def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5 threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item() dim3 = torch.randint(32, 128, size=(1,)).item()
@@ -1453,6 +1276,7 @@ def test_spmm_coo(dim1, dim2, transposed_B):
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)

@pytest.mark.benchmark
def test_spmm_bench():
    batch = 2
    model = 1024 * 1
@@ -1496,14 +1320,8 @@ def test_spmm_bench():
print(tsp / t8) print(tsp / t8)
n = 2 @pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1"))
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() @pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2"))
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2): def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0 threshold = 3.0
formatB = "col_turing" formatB = "col_turing"
...@@ -1553,23 +1371,10 @@ def test_matmuls(): ...@@ -1553,23 +1371,10 @@ def test_matmuls():
print(err1, err2) print(err1, err2)
n = 2 @pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() @pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() @pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
dim1 = [1 * 2048] @pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
"dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func) out_func = getattr(torch, out_func)
...@@ -1672,20 +1477,9 @@ def test_coo2csc(): ...@@ -1672,20 +1477,9 @@ def test_coo2csc():
torch.testing.assert_close(A2.t()[idx], cscA.values) torch.testing.assert_close(A2.t()[idx], cscA.values)
n = 2 @pytest.mark.parametrize("dim1", [1 * 2048])
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() @pytest.mark.parametrize("dim2", [2048])
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() @pytest.mark.parametrize("dtype", [torch.int8])
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype): def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0 threshold = 6.0
# threshold = 2.8 # threshold = 2.8
@@ -1786,22 +1580,11 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
    print("partial matmul", time.time() - t0)

batch_size = 1
seqdim = 1
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
#values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)

@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
    iters = 1000
    formatB = F.get_special_format_str()
@@ -2226,6 +2009,7 @@ def test_kbit_quantile_estimation():
    assert err < 0.035

@pytest.mark.benchmark
def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    code =F.create_fp8_map(True, 3, 0, 4).cuda()
@@ -2244,7 +2028,7 @@ def test_bench_dequantization():

@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
def test_fp4_quant(dtype):
    vals = list(product([0, 1], repeat=4))
@@ -2321,6 +2105,7 @@ def test_4bit_compressed_stats(quant_type):

#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
@@ -2367,11 +2152,11 @@ def test_normal_map_tree():
    #print(pivots)

@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32'])

@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
    #for dim in [4*1024]:
@@ -2537,12 +2322,12 @@ def test_managed():

@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = torch.randint(0, 8192, size=(dims,)).tolist()
    dims = get_test_dims(0, 8192, n=dims)
    dims = [dim + (64-(dim % 64)) for dim in dims]
    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
...
@@ -9,6 +9,8 @@ from transformers import (
    BitsAndBytesConfig,
)

from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter

def get_4bit_config():
    return BitsAndBytesConfig(
@@ -59,23 +61,19 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
dtypes = ['nf4', 'fp4']
load_in_4bit = [True, False]
values = list(product(models, dtypes))
strfunc = lambda lst: [str(x) for x in lst]
ids = ['_'.join(strfunc(x)) for x in values]
@pytest.fixture(scope='session', params=values, ids=ids)

@pytest.fixture(scope='session', params=product(models, dtypes))
def model_and_tokenizer(request):
    model, tokenizer = get_model_and_tokenizer(request.param)
    yield request.param, model, tokenizer
    del model

@pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
#@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ):
    print('')
    dtype = torch.float16

@pytest.mark.parametrize("DQ", TRUE_FALSE, ids=id_formatter("dq"))
@pytest.mark.parametrize("inference_kernel", TRUE_FALSE, ids=id_formatter("inference_kernel"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.slow
def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
    fixture_config, model, tokenizer = model_and_tokenizer
    generation_config = transformers.GenerationConfig(
...
from itertools import product
import os
from tempfile import TemporaryDirectory
@@ -6,6 +5,7 @@ import pytest
import torch

import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE

storage = {
    'uint8': torch.uint8,
@@ -14,10 +14,10 @@ storage = {
    'float32': torch.float32
}

@pytest.mark.parametrize(
    "quant_type, compress_statistics, bias, quant_storage",
    list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])),
)

@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
    original_dtype = torch.float16
    compute_dtype = None
...
from contextlib import nullcontext
from itertools import product
import os
from tempfile import TemporaryDirectory
@@ -10,6 +9,7 @@ import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import TRUE_FALSE, id_formatter

# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@@ -66,8 +66,10 @@ def test_linear_no_igemmlt():
    assert linear_custom.state.CxB is None

@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
                         list(product([False, True], [False, True], [False, True], [False, True])))

@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
    linear = torch.nn.Linear(32, 96)
    x = torch.randn(3, 32, dtype=torch.half)
...
@@ -6,6 +6,7 @@ import torch
from torch import nn

import bitsandbytes as bnb
from tests.helpers import id_formatter

class MockArgs:
@@ -311,12 +312,7 @@ class Linear8bit(nn.Module):
        return LinearFunction.apply(x, self.weight, self.bias, self.args)

threshold = [0.0, 3.0]
values = threshold
names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)

@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
@@ -510,18 +506,21 @@ def test_linear_kbit_fp32_bias(module):
    o1 = l1(b1)
    assert l1.bias is None

modules = []
modules.append(bnb.nn.Linear8bitLt)
modules.append(bnb.nn.Linear4bit)
modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
@pytest.mark.parametrize("module", modules, ids=names)

module_dict = {
    "Int8Lt": bnb.nn.Linear8bitLt,
    "4bit": bnb.nn.Linear4bit,
    "FP4": bnb.nn.LinearFP4,
    "NF4": bnb.nn.LinearNF4,
    "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
    "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
    "NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
    "NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
    "NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
}
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
    b = 17
    dim1 = 37
...
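The dict-based rewrite above is a small but useful pattern: keeping the human-readable name as the key and the module factory as the value means the ids can never drift out of sync with the values, and parametrize accepts the dict_keys/dict_values views directly (dicts preserve insertion order, so the two stay aligned). A self-contained sketch of the same idea using plain torch.nn.Linear factories as stand-ins for the bnb modules; the names and shapes below are illustrative only:

import pytest
import torch

# Stand-in factories; in the real test these are bnb.nn.Linear8bitLt, Linear4bit, etc.
linear_factories = {
    "default": torch.nn.Linear,
    "no_bias": lambda d_in, d_out: torch.nn.Linear(d_in, d_out, bias=False),
}


@pytest.mark.parametrize("make_linear", linear_factories.values(), ids=linear_factories.keys())
def test_factory_output_shape(make_linear):
    layer = make_linear(8, 4)  # every value behaves like a Linear constructor
    out = layer(torch.randn(2, 8))
    assert out.shape == (2, 4)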
-from itertools import product
import os
from os.path import join
import shutil
@@ -11,6 +10,7 @@ import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
+from tests.helpers import describe_dtype, id_formatter
# import apex
@@ -101,15 +101,16 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
-dim1 = [1024]
-dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16, torch.bfloat16]
-optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
-values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
+optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
+@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
+@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
+@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
-    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
+    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']:
+        pytest.skip()
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@@ -134,7 +135,6 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        bnb_optimizer.step()
        torch_optimizer.step()
    for name1, name2 in str2statenames[optim_name]:
        torch.testing.assert_close(
            torch_optimizer.state[p1][name1],
@@ -177,14 +177,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
-dim1 = [1024]
-dim2 = [32, 1024, 4097]
-gtype = [torch.float32, torch.float16]
-values = list(product(dim1, dim2, gtype))
-names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
+@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
+@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
+@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_global_config(dim1, dim2, gtype):
    if dim1 == 1 and dim2 == 1:
        return
@@ -230,10 +225,7 @@ def test_global_config(dim1, dim2, gtype):
    assert adam2.state[p3]["state2"].dtype == torch.uint8
-dim1 = [1024]
-dim2 = [32, 1024, 4097]
-gtype = [torch.float32, torch.float16, torch.bfloat16]
-optimizer_names = [
+optimizer_names_8bit = [
    "adam8bit",
    "lion8bit",
    "momentum8bit",
@@ -243,13 +235,12 @@ optimizer_names = [
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
-values = list(product(dim1, dim2, gtype, optimizer_names))
-names = [
-    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
-]
-@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
+@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt"))
+@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
+@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
    if dim1 == 1 and dim2 == 1:
@@ -375,18 +366,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    # print(sum(relerrors)/len(relerrors))
-dim1 = [1024]
-dim2 = [32, 1024, 4097]
-gtype = [torch.float32]
-optim_bits = [32, 8]
-values = list(product(dim1, dim2, gtype, optim_bits))
-names = [
-    "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
-    for vals in values
-]
-@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
+@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits"))
+@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
+@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    if dim1 == 1 and dim2 == 1:
        return
@@ -474,22 +457,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    adam2.load_state_dict(torch.load(join(path, "opt.pt")))
-dim1 = [4096]
-dim2 = [4096]
-gtype = [torch.float32, torch.float16]
-# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
-# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
-# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
-# optimizer_names = ['lamb_apex', 'lamb8bit']
-# optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
-values = list(product(dim1, dim2, gtype, optimizer_names))
-names = [
-    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
-]
-@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
+optimizer_names_benchmark = [
+    "adam8bit_blockwise",
+    "paged_adam8bit_blockwise",
+    "paged_adamw8bit_blockwise",
+    "paged_lion8bit_blockwise",
+]
+@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
+@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
+@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
+@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
+@pytest.mark.benchmark
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
@@ -514,15 +494,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    print(optim_name, gtype, s / params)
    # assert s < 3.9
-dim1 = [2*1024]
-gtype = [torch.float16]
-#mode = ['torch', 'bnb']
-mode = ['bnb']
-optimizer_names = ['paged_adamw']
-#optimizer_names = ['paged_adamw8bit_blockwise']
-values = list(product(dim1,gtype, optimizer_names, mode))
-names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
+@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
+@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
+@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name"))
+@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode"))
+@pytest.mark.benchmark
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
    layers1 = layers1.to(gtype)
...
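Two details of the optimizer-test changes above are worth a combined sketch (illustrative only, not taken from the diff): describe_dtype is used as a callable ids formatter so torch.float16 shows up as fp16 in the generated test ids, and the benchmark marker, assumed here to be registered in the project's pytest configuration as this PR does, lets the heavy runs be deselected with pytest -m "not benchmark" or run on their own with -m benchmark:

import pytest
import torch

from tests.helpers import describe_dtype, id_formatter


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("dim", [4096], ids=id_formatter("dim"))
@pytest.mark.benchmark
def test_example_benchmark(dim, gtype):
    # Ids read like "dim=4096-fp16" rather than the old
    # "dim1_4096_dim2_4096_gtype_torch.float16_optim_..." style.
    x = torch.randn(dim, dtype=gtype)
    assert x.dtype == gtype

A regular pytest run still collects this test; only an explicit -m filter changes what gets executed.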
@@ -4,11 +4,12 @@ import torch
from bitsandbytes.nn import Linear8bitLt
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.triton.triton_utils import is_triton_available
+from tests.helpers import TRUE_FALSE
@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
-@pytest.mark.parametrize("vector_wise_quantization", [False, True])
+@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
def test_switchback(vector_wise_quantization):
    for dim in [83]:
        for batch in [13]:
...