Unverified Commit 1088ec52 authored by Matthew Douglas's avatar Matthew Douglas Committed by GitHub
Browse files

Updates for device agnosticism (#1601)

* Include device support tags for transformers multi-backend compatability; add xpu() and cpu() to Params4bit

* Make test suite more device-agnostic

* Additional device agnostic tests

* Additional device agnosticism for tests

* Add BNB_TEST_DEVICE env var to manually select device for unit tests

* Include device support tags for transformers multi-backend compatability; add xpu() and cpu() to Params4bit

* Make test suite more device-agnostic

* Additional device agnostic tests

* Additional device agnosticism for tests

* Add BNB_TEST_DEVICE env var to manually select device for unit tests

* Small bugfix for int8 test

* Exclude backward() from code coverage reports

* Params4bit: don't try to quantize when moving to meta device
parent 97073cdb
......@@ -20,15 +20,15 @@ from .nn import modules
from .optim import adam
# This is a signal for integrations with transformers/diffusers.
# Eventually, we will remove this and check based on release version.
# Eventually we may remove this but it is currently required for compatibility.
features = {"multi-backend"}
supported_torch_devices = {
"cuda",
"cpu",
# "mps",
# "xpu",
# "hpu",
# "npu",
"cuda", # NVIDIA/AMD GPU
"xpu", # Intel GPU
"hpu", # Gaudi
"npu", # Ascend NPU
"mps", # Apple Silicon
}
if torch.cuda.is_available():
......
......@@ -284,7 +284,7 @@ class MatMul8bitLt(torch.autograd.Function):
dtype=torch.float16,
)
if state.threshold > 0.0 and subA is not None:
if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
......
......@@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
for i in range(gap):
values.append(0)
values.sort()
code = torch.Tensor(values)
code = torch.tensor(values)
code /= code.max()
return code
......
......@@ -306,9 +306,15 @@ class Params4bit(torch.nn.Parameter):
self.bnb_quantized = True
return self
def cpu(self):
return self.to(device="cpu")
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
@overload
def to(
self: T,
......@@ -326,7 +332,7 @@ class Params4bit(torch.nn.Parameter):
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "cuda" and not self.bnb_quantized:
if device is not None and device.type != "meta" and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
......
......@@ -79,6 +79,12 @@ include = ["bitsandbytes*"]
[tool.setuptools.dynamic]
version = {attr = "bitsandbytes.__version__"}
[tool.coverage.report]
exclude_also = [
# exclude backward() functions from coverage, as they are invoked from C++
'def backward\(ctx'
]
[tool.pytest.ini_options]
addopts = "-rP -m 'not slow and not benchmark and not deprecated'"
# ; --cov=bitsandbytes
......
import functools
from io import BytesIO
from itertools import product
import os
import random
from typing import Any
......@@ -13,6 +15,38 @@ BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3)) # all combinations of (bo
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool)
@functools.cache
def get_available_devices():
if "BNB_TEST_DEVICE" in os.environ:
# If the environment variable is set, use it directly.
return [os.environ["BNB_TEST_DEVICE"]]
devices = ["cpu"]
if hasattr(torch, "accelerator"):
# PyTorch 2.6+ - determine accelerator using agnostic API.
if torch.accelerator.is_available():
devices += [str(torch.accelerator.current_accelerator())]
else:
if torch.cuda.is_available():
devices += ["cuda"]
if torch.backends.mps.is_available():
devices += ["mps"]
if hasattr(torch, "xpu") and torch.xpu.is_available():
devices += ["xpu"]
custom_backend_name = torch._C._get_privateuse1_backend_name()
custom_backend_module = getattr(torch, custom_backend_name, None)
custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)
if custom_backend_is_available_fn and custom_backend_module.is_available():
devices += [custom_backend_name]
return devices
def torch_save_to_buffer(obj):
buffer = BytesIO()
torch.save(obj, buffer)
......
......@@ -6,12 +6,14 @@ from tests.helpers import (
BOOLEAN_TRIPLES,
TRUE_FALSE,
describe_dtype,
get_available_devices,
id_formatter,
)
TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
......@@ -27,10 +29,16 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
def test_matmullt(
device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
):
if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
# TODO: Deprecate/remove?
pytest.skip("switchback_bnb only works on CUDA.")
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
if has_bias == False:
req_grad = list(req_grad)
req_grad[2] = False
......@@ -38,21 +46,21 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(
size=(dim2, dim4),
device="cuda",
device=device,
requires_grad=req_grad[1],
dtype=dtype,
)
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
......@@ -91,6 +99,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
if has_fp16_weights:
if any(req_grad):
out_bnb.data.copy_(out_torch)
if device == "cuda":
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
......@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
torch.testing.assert_close(gradBias1, gradBias2)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
......@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
def test_matmul_4bit(
device,
dim1,
dim2,
dim3,
......@@ -159,6 +170,9 @@ def test_matmul_4bit(
compress_statistics,
quant_type,
):
if device == "cpu" and quant_type == "fp4":
pytest.skip("Only nf4 is supported on CPU")
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
if has_bias == False:
......@@ -168,13 +182,13 @@ def test_matmul_4bit(
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
......@@ -204,6 +218,7 @@ def test_matmul_4bit(
# assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
if device == "cuda":
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
......
......@@ -13,6 +13,7 @@ from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
describe_dtype,
get_available_devices,
get_test_dims,
id_formatter,
)
......@@ -87,15 +88,26 @@ class Timer:
class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed):
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
if device == "cpu":
# This test is slow on CPU, so avoid atypical use cases.
if nested:
pytest.skip("Not a typical use case.")
if blocksize != 256:
pytest.skip("Only blocksize 256 is the typical one supported on CPU.")
if dtype != torch.float32:
pytest.xfail(f"CPU implementation currently only supports float32, got {dtype}")
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float()
......@@ -113,7 +125,7 @@ class Test8BitBlockwiseQuantizeFunctional:
diffs = []
code = F.create_dynamic_map(signed=signed)
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float()
......@@ -154,21 +166,27 @@ class Test8BitBlockwiseQuantizeFunctional:
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
def test_few_bit_quant(self, bits, method):
def test_few_bit_quant(self, device, bits, method):
if device == "cpu" and bits != 8:
pytest.skip("CPU implementation only supports 8 bits")
abserrs = []
relerrs = []
code = None
if method == "linear":
code = F.create_linear_map(True, total_bits=bits).cuda()
code = F.create_linear_map(True, total_bits=bits).to(device)
elif method == "fp8":
ebits = math.ceil(bits / 2)
pbits = bits - ebits - 1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
code = F.create_fp8_map(True, ebits, pbits, bits).to(device)
elif method == "dynamic":
code = F.create_dynamic_map(True, bits - 0, bits).cuda()
code = F.create_dynamic_map(True, bits - 0, bits).to(device)
elif method == "quantile":
if device != "cuda":
pytest.xfail("Quantile map only works on CUDA")
values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
......@@ -178,7 +196,7 @@ class Test8BitBlockwiseQuantizeFunctional:
# print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device="cuda")
values = torch.randn(1, 32, device=device)
values /= values.abs().max()
# values[values.abs() < 1e-6] += 1e-5
......@@ -189,8 +207,8 @@ class Test8BitBlockwiseQuantizeFunctional:
q1.append(idx.item())
v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q1 = torch.tensor(q1, device=device)
v1 = torch.tensor(v1, device=device)
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
......@@ -206,15 +224,20 @@ class Test8BitBlockwiseQuantizeFunctional:
else:
torch.testing.assert_close(q1, q2)
def test_fp8_quant(self):
@pytest.mark.parametrize("device", get_available_devices())
def test_fp8_quant(self, device):
# TODO
if device == "cpu":
pytest.skip("CPU implementation segfaults")
for e_bits in range(1, 7):
p_bits = 7 - e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
code = F.create_fp8_map(True, e_bits, p_bits).to(device)
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
A1 = torch.randn(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
......@@ -228,7 +251,7 @@ class Test8BitBlockwiseQuantizeFunctional:
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
A1 = torch.rand(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
......@@ -242,7 +265,7 @@ class Test8BitBlockwiseQuantizeFunctional:
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
A1 = torch.randn(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
......@@ -329,6 +352,7 @@ methods = {
}
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestIGEMMFunctional:
@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
......@@ -532,36 +556,38 @@ class TestIGEMMFunctional:
class TestLLMInt8Functional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_int8_linear_matmul(self, dim1, dim2, dim3, dim4, dims, ldb):
def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k):
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
A = torch.randint(-128, 127, size=(dim1, dim3), dtype=torch.int8, device=device)
elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, size=(dim4, dim3), dtype=torch.int8, device=device)
C1 = torch.matmul(A.float(), B.t().float())
C2 = F.int8_linear_matmul(A, B)
torch.testing.assert_close(C1, C2.float())
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims):
def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims):
for i in range(k):
if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
A = torch.normal(0, 0.5, size=(dim1, dim3), device=device).half()
elif dims == 3:
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
B = torch.randn((dim4, dim3), device="cuda").half()
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device=device).half()
B = torch.randn((dim4, dim3), device=device).half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
......@@ -573,19 +599,20 @@ class TestLLMInt8Functional:
torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(self, dim1, dim4, dims, has_bias):
def test_dequant_mm(self, device, dim1, dim4, dims, has_bias):
inner = 128
bias = None
if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
bias = torch.randn(dim4, device=device, dtype=torch.float16)
for i in range(1):
A = torch.randn(dim1, inner, device="cuda")
B = torch.randn(dim4, inner, device="cuda")
A = torch.randn(dim1, inner, device=device)
B = torch.randn(dim4, inner, device=device)
C1 = torch.matmul(A.half(), B.t().half())
if has_bias:
C1 += bias
......@@ -618,6 +645,7 @@ class TestLLMInt8Functional:
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
@pytest.mark.deprecated
def test_colrow_absmax(self, dim1, dim2, dims, threshold):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
......@@ -654,6 +682,7 @@ class TestLLMInt8Functional:
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
@pytest.mark.deprecated
def test_int8_double_quant(self, dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
......@@ -686,6 +715,7 @@ class TestLLMInt8Functional:
torch.testing.assert_close(Srow.flatten().float(), statsA)
torch.testing.assert_close(Scol.flatten().float(), statsAt)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
(
......@@ -697,10 +727,10 @@ class TestLLMInt8Functional:
)
),
)
def test_integrated_int8_linear_matmul(self, dim1, dim4, inner):
def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner):
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
A = torch.randn(dim1, inner, device=device).half()
B = torch.randn(dim4, inner, device=device).half()
out1 = torch.matmul(A.half(), B.t().half())
......@@ -724,12 +754,13 @@ class TestLLMInt8Functional:
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.025
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(self, dim1, dim2):
def test_coo_double_quant(self, device, dim1, dim2):
threshold = 2.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
A = torch.randn(dim1, dim2, device=device).half()
idx = torch.abs(A) >= threshold
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
......@@ -743,12 +774,13 @@ class TestLLMInt8Functional:
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_int8_vectorwise_quant(self, dim1, dim2):
def test_coo_int8_vectorwise_quant(self, device, dim1, dim2):
threshold = 3.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
A = torch.randn(dim1, dim2, device=device).half()
idx = torch.abs(A) >= threshold
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
......@@ -759,6 +791,7 @@ class TestLLMInt8Functional:
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestSpMMFunctional:
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2"))
......@@ -1025,6 +1058,7 @@ class TestSpMMFunctional:
print("partial matmul", time.time() - t0)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestSparseTensorFunctional:
def test_coo2csr(self):
threshold = 1
......@@ -1063,11 +1097,12 @@ class TestSparseTensorFunctional:
class TestQuantize4BitFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(self, dtype, quant_type, blocksize):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
......@@ -1095,13 +1130,14 @@ class TestQuantize4BitFunctional:
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
assert err.item() < math.log2(blocksize) * 8e-2
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
def test_4bit_compressed_stats(self, quant_type, blocksize):
def test_4bit_compressed_stats(self, device, quant_type, blocksize):
errs1 = []
errs2 = []
for i in range(10):
A1 = torch.randn(1024, 1024, device="cuda").half()
A1 = torch.randn(1024, 1024, device=device).half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
......@@ -1127,6 +1163,7 @@ class TestQuantize4BitFunctional:
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.benchmark
def test_bench_4bit_dequant(self, quant_type):
blocksize = 256
......@@ -1157,6 +1194,7 @@ class TestQuantize4BitFunctional:
# torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
......@@ -1167,7 +1205,7 @@ class TestQuantize4BitFunctional:
ids=describe_dtype,
)
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, kind):
def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
errs1 = []
errs2 = []
errs3 = []
......@@ -1180,17 +1218,17 @@ class TestQuantize4BitFunctional:
for i in range(100):
if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
A = torch.randn(1, dim, dtype=dtype, device=device)
B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim)
elif kind == "fc2":
A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
A = torch.randn(1, 4 * dim, dtype=dtype, device=device)
B = torch.randn(dim, 4 * dim, dtype=dtype, device=device) / math.sqrt(dim)
elif kind == "attn":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
A = torch.randn(1, dim, dtype=dtype, device=device)
B = torch.randn(dim, dim, dtype=dtype, device=device) / math.sqrt(dim)
elif kind == "attn_packed":
A = torch.randn(1, dim, dtype=dtype, device="cuda")
B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
A = torch.randn(1, dim, dtype=dtype, device=device)
B = torch.randn(dim * 3, dim, dtype=dtype, device=device) / math.sqrt(dim)
qB, state = F.quantize_4bit(
B,
......@@ -1294,18 +1332,19 @@ class TestQuantize4BitFunctional:
assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(self, storage_type, dtype, double_quant):
def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims)
dims = [dim + (64 - (dim % 64)) for dim in dims]
# for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
for dim in dims:
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
B = torch.eye(dim, dtype=dtype, device="cuda")
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
B = torch.eye(dim, dtype=dtype, device=device)
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t())
......
......@@ -7,7 +7,7 @@ import pytest
import torch
import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
storage = {
"uint8": torch.uint8,
......@@ -17,15 +17,18 @@ storage = {
}
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
if device == "cpu":
pytest.xfail("Dequantization is not yet implemented for CPU")
original_dtype = torch.float16
compute_dtype = None
device = "cuda"
layer_shape = (300, 400)
linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer
......@@ -52,7 +55,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
# restoring from state_dict:
bias_data2 = sd.pop("bias", None)
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)
# creating new layer with same params:
linear_q2 = bnb.nn.Linear4bit(
......@@ -174,18 +177,50 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert size_ratio < target_compression, ratio_error_msg
def test_copy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
tensor = torch.linspace(1, blocksize, blocksize)
param = bnb.nn.Params4bit(
data=tensor,
quant_type=quant_type,
blocksize=blocksize,
compress_statistics=compress_statistics,
requires_grad=False,
).to(device)
shallow_copy_param = copy.copy(param)
assert param.quant_state is shallow_copy_param.quant_state
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
def test_deepcopy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
tensor = torch.linspace(1, blocksize, blocksize)
param = bnb.nn.Params4bit(
data=tensor,
quant_type=quant_type,
blocksize=blocksize,
compress_statistics=compress_statistics,
requires_grad=False,
).to(device)
dict_keys_before = set(param.__dict__.keys())
copy_param = copy.deepcopy(param)
dict_keys_after = set(param.__dict__.keys())
......@@ -199,12 +234,27 @@ def test_deepcopy_param():
assert dict_keys_before == dict_keys_copy
def test_params4bit_real_serialization():
original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
original_param = bnb.nn.Params4bit(
data=original_tensor,
quant_type=quant_type,
blocksize=blocksize,
compress_statistics=compress_statistics,
)
dict_keys_before = set(original_param.__dict__.keys())
original_param.cuda(0) # move to CUDA to trigger quantization
original_param.to(device) # change device to trigger quantization
serialized_param = pickle.dumps(original_param)
deserialized_param = pickle.loads(serialized_param)
......
......@@ -11,6 +11,7 @@ import bitsandbytes as bnb
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
get_available_devices,
id_formatter,
torch_load_from_buffer,
torch_save_to_buffer,
......@@ -19,7 +20,11 @@ from tests.helpers import (
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
def test_linear_no_igemmlt():
@pytest.mark.parametrize("device", get_available_devices())
def test_linear_no_igemmlt(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
linear = torch.nn.Linear(1024, 3072)
x = torch.randn(3, 1024, dtype=torch.half)
linear_custom = Linear8bitLt(
......@@ -29,6 +34,8 @@ def test_linear_no_igemmlt():
has_fp16_weights=False,
threshold=6.0,
)
# TODO: Remove, this is no longer implemented
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
......@@ -37,11 +44,11 @@ def test_linear_no_igemmlt():
has_fp16_weights=False,
).to(linear.weight.dtype)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
linear = linear.half().cuda()
linear_custom = linear_custom.to(device)
linear = linear.half().to(device)
x_ref = x.clone().cuda().requires_grad_(True)
x_ours = x.clone().cuda().requires_grad_(True)
x_ref = x.clone().to(device).requires_grad_(True)
x_ours = x.clone().to(device).requires_grad_(True)
fx_ref = linear(x_ref).float()
grad_proj = torch.randn_like(fx_ref)
(fx_ref * grad_proj).mean().backward()
......@@ -58,18 +65,25 @@ def test_linear_no_igemmlt():
torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
def test_linear_serialization(
device,
has_fp16_weights,
threshold,
serialize_before_forward,
deserialize_before_cuda,
save_before_forward,
load_before_cuda,
):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
linear = torch.nn.Linear(32, 96)
# TODO: Fallback for bad shapes
x = torch.randn(4, 32, dtype=torch.half)
......@@ -80,7 +94,7 @@ def test_linear_serialization(
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
threshold=threshold,
)
linear_custom.weight = bnb.nn.Int8Params(
......@@ -89,7 +103,7 @@ def test_linear_serialization(
has_fp16_weights=has_fp16_weights,
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
linear_custom = linear_custom.to(device)
if serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()
......@@ -125,7 +139,7 @@ def test_linear_serialization(
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
threshold=threshold,
)
if deserialize_before_cuda:
......@@ -135,7 +149,7 @@ def test_linear_serialization(
if load_before_cuda:
new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
new_linear_custom = new_linear_custom.cuda()
new_linear_custom = new_linear_custom.to(device)
if not deserialize_before_cuda:
new_linear_custom.load_state_dict(new_state_dict, strict=True)
......
import inspect
import math
import einops
import pytest
import torch
from torch import nn
import bitsandbytes as bnb
from tests.helpers import id_formatter
from tests.helpers import get_available_devices, id_formatter
class MockArgs:
......@@ -54,266 +52,32 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
norm = math.sqrt(math.pi) / math.sqrt(2.0)
# std = torch.abs(x).mean()*norm
std = torch.std(x)
max1 = std * trim_value
x = x / max1 * 127
x = round_func(x)
x[x > 127] = 127
x[x < -127] = -127
x = x / 127 * max1
return x
def quant(x, quant_type, dim=1):
if quant_type == "linear":
max1 = torch.abs(x).max().float()
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type == "vector":
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type == "min-max":
maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float()
scale = (maxA - minA) / 2.0
xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
return xq, (minA.float(), scale.float())
else:
return None
def dequant(xq, S1, S2, dtype, quant_type):
if quant_type == "linear":
norm = S1 * S2 / (127 * 127)
# double cast needed to prevent overflows
return (xq.float() * norm).to(dtype)
elif quant_type == "vector":
x = xq.float()
if len(xq.shape) == 2 and len(S1.shape) == 3:
S1 = S1.squeeze(0)
if len(xq.shape) == 2 and len(S2.shape) == 3:
S2 = S2.squeeze(0)
# print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2:
x *= S1.t() / 127
else:
x *= S1 / 127
x *= S2 / 127
return x.to(dtype)
else:
return None
def dequant_min_max(xq, A, B, SA, SB, dtype):
offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3:
SB = SB.squeeze(0)
if len(xq.shape) == 2 and len(SA.shape) == 3:
SA = SA.squeeze(0)
if len(SB.shape) == 2:
x *= SB.t() / 127
else:
x *= SB / 127
x *= SA[1] / 127
x += offset
return x.to(dtype)
def get_8bit_linear(x, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.abs(x).max()
x = x / max1 * 127
x = round_func(x) / 127 * max1
# x = torch.round(x)/128*max1
return x
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = (x * 127) / max1
x = round_func(x) / 127 * max1
return x
@staticmethod
def round_stoachastic(x):
sign = torch.sign(x)
absx = torch.abs(x)
decimal = absx - torch.floor(absx)
rdm = torch.rand_like(decimal)
return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
@staticmethod
def fake_8bit_storage(w, exponent_bits):
code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
# C = bnb.functional.quantize_no_absmax(code, w)
# out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
# print(out)
# out = out.half()
code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_stoachstic(w):
rand = torch.rand(1024, device=w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
out = bnb.functional.dequantize_blockwise(absmax, C)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_with_max(w, topk=8):
blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk]
max_val = max_val[:, :topk]
mask = torch.zeros_like(blocked_w)
mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
mask = mask.bool()
# 1. zero out max values
# 2. quantize + dequantize
# 3. write back max values
# 4. copy matrix back to weight
values = blocked_w[mask]
blocked_w[mask] = 0
code = bnb.functional.create_dynamic_map()
code = code.to(w.device)
absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
blocked_w[mask] = values
unblocked_w = blocked_w.flatten().view(w.shape)
w.copy_(unblocked_w)
return unblocked_w
@staticmethod
def forward(ctx, x, weight, bias=None, args=None):
if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
# if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float()
# relerr = err/(torch.abs(output32).float()+1e-8)
# print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else:
# output = torch.matmul(x, weight.t())
output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias)
ctx.args = args
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
x, weight, bias = ctx.saved_tensors
args = ctx.args
stochastic = False
grad_input = grad_weight = grad_bias = None
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
# weight and x are already 8bit
# -> transform grad_output to 8-bit
if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
else:
grad_input = grad_output.matmul(weight)
grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super().__init__()
self.input_features = input_features
self.output_features = output_features
self.args = args
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
torch.nn.init.zeros_(self.bias)
def forward(self, x):
self.args.training = self.training
return LinearFunction.apply(x, self.weight, self.bias, self.args)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold):
l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
assert l1.weight.device.type == "cuda"
assert l1.weight.dtype == torch.float16
def test_linear8bitlt_inference(device, threshold):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
assert l1.weight.device.type == device
assert l1.weight.dtype == torch.int8
l1.eval()
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device).half()
o1 = l1(b1)
if i == 1:
assert l1.state.CB is not None
def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
# TODO: Remove support for training int8 weights
@pytest.mark.parametrize("device", get_available_devices())
def test_linear8bitlt_accumulated_gradient(device):
if device != "cuda":
pytest.skip("Only supported on CUDA")
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)])
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).to(device).half() for i in range(2)])
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
l1[0].bias.data.copy_(l2[0].bias.data)
......@@ -325,7 +89,7 @@ def test_linear8bitlt_accumulated_gradient():
acc_steps = 10
for i in range(15):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device).half()
o1 = l1(b1)
o2 = l2(b1)
loss1 = o1.mean()
......@@ -353,8 +117,12 @@ def test_linear8bitlt_accumulated_gradient():
assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 2.0])
def test_linear8bitlt_no_fp16_weights(threshold):
def test_linear8bitlt_no_fp16_weights(device, threshold):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
l1 = (
bnb.nn.Linear8bitLt(
32,
......@@ -362,23 +130,23 @@ def test_linear8bitlt_no_fp16_weights(threshold):
threshold=threshold,
has_fp16_weights=False,
)
.cuda()
.to(device)
.half()
)
assert l1.weight.dtype == torch.int8
l1.eval()
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = l1(b1)
assert o1.dtype == torch.float16
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device)
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
......@@ -386,12 +154,12 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
......@@ -399,10 +167,10 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
......@@ -420,11 +188,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
has_fp16_weights=False,
)
.half()
.to("cuda")
.to(device)
)
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
......@@ -433,8 +201,8 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
assert mlp.fc1.weight.device.type == device
assert mlp.fc2.weight.device.type == device
mlp = MLP8bit(
32,
......@@ -442,11 +210,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
threshold=threshold,
has_fp16_weights=False,
)
w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization,
w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device) # grab weights before quantization,
mlp = mlp.cuda().half() # and this line triggers quantization
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0:
......@@ -456,10 +224,10 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
assert mlp.fc1.weight.device.type == device
assert mlp.fc2.weight.device.type == device
b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
b1 = torch.randn(16, 8, 32, device=device, requires_grad=True, dtype=torch.half)
o1 = mlp(b1)
assert o1.dtype == torch.float16
assert o1.requires_grad
......@@ -475,33 +243,37 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert (idx == 0).sum().item() <= b1.numel() * 0.005
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
"module",
[
lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
bnb.nn.LinearFP4,
bnb.nn.LinearNF4,
],
ids=["Int8Lt", "FP4"],
ids=["Int8Lt", "NF4"],
)
def test_linear_kbit_fp32_bias(module):
def test_linear_kbit_fp32_bias(device, module):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
# casts model to fp16 -> int8 automatically
l1 = module(32, 64).cuda()
l1 = module(32, 64).to(device)
assert l1.weight.dtype in [torch.int8, torch.uint8]
assert l1.bias.dtype == torch.float32
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
# casts bias to fp32
o1 = l1(b1)
assert l1.bias.dtype == torch.float16
# casts model to fp16 -> int8 automatically
l1 = module(32, 64, bias=False).cuda()
l1 = module(32, 64, bias=False).to(device)
assert l1.weight.dtype in [torch.int8, torch.uint8]
assert l1.bias is None
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = l1(b1)
assert l1.bias is None
......@@ -519,8 +291,12 @@ module_dict = {
}
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
def test_kbit_backprop(device, module):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
b = 16
dim1 = 36
dim2 = 84
......@@ -536,16 +312,16 @@ def test_kbit_backprop(module):
kbit[1].weight.detach().copy_(ref[1].weight)
kbit[0].bias.detach().copy_(ref[0].bias)
kbit[1].bias.detach().copy_(ref[1].bias)
ref = ref.half().cuda()
kbit = kbit.half().cuda()
kbit = kbit.half().to("cuda")
ref = ref.half().to(device)
kbit = kbit.half().to(device)
kbit = kbit.half().to(device)
errs1 = []
errs2 = []
relerrs1 = []
relerrs2 = []
for i in range(100):
batch = torch.randn(b, dim1).half().cuda()
batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
out1 = ref(batch)
out2 = kbit(batch)
out1.mean().backward()
......@@ -578,6 +354,7 @@ def test_kbit_backprop(module):
assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
@pytest.mark.deprecated
def test_fp8linear():
b = 10
h = 1024
......@@ -608,6 +385,7 @@ def test_fp8linear():
assert bgraderr < 0.00002
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
......@@ -621,7 +399,10 @@ def test_fp8linear():
],
ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
num_embeddings = 128
src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
......@@ -641,10 +422,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
e.load_state_dict(emb_base.state_dict())
emb_base.cuda()
e.cuda()
emb_base.to(device)
e.to(device)
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
torch.testing.assert_close(
actual=e(input_tokens),
......@@ -652,6 +433,7 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
......@@ -665,7 +447,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
],
ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
is_8bit = embedding_class is bnb.nn.Embedding8bit
num_embeddings = 128
......@@ -685,10 +470,10 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
e.load_state_dict(emb_base.state_dict())
emb_base.cuda()
e.cuda()
emb_base.to(device)
e.to(device)
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
torch.testing.assert_close(
actual=e(input_tokens),
......@@ -698,46 +483,64 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
)
def test_4bit_linear_warnings():
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_linear_warnings(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
dim1 = 64
with pytest.warns(UserWarning, match=r"inference or training"):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(10, dim1).cuda().half()
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp)
with pytest.warns(UserWarning, match=r"inference."):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(1, dim1).cuda().half()
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp)
with pytest.warns(UserWarning) as record:
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(10, dim1).cuda().half()
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp)
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(1, dim1).cuda().half()
net = nn.Sequential(
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
)
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp)
assert len(record) == 2
def test_4bit_embedding_warnings():
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_embedding_warnings(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
num_embeddings = 128
default_block_size = 64
with pytest.warns(UserWarning, match=r"inference."):
net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
net.cuda()
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
net = bnb.nn.Embedding4bit(
num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
)
net.to(device)
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
net(inp)
def test_4bit_embedding_weight_fsdp_fix():
def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
num_embeddings = 64
embedding_dim = 32
......@@ -754,7 +557,7 @@ def test_4bit_embedding_weight_fsdp_fix():
assert module.weight.quant_state is not None
def test_4bit_linear_weight_fsdp_fix():
def test_4bit_linear_weight_fsdp_fix(requires_cuda):
inp_size = 64
out_size = 32
......
......@@ -4,11 +4,11 @@ import pytest
import torch
import bitsandbytes
from tests.helpers import TRUE_FALSE, id_formatter
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
class TestLLMInt8Ops:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_linear_matmul(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
......@@ -20,7 +20,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_linear_matmul_out(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
......@@ -35,7 +35,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))
@pytest.mark.parametrize("threshold", [0.0, 6.0])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_vectorwise_quant(self, threshold, device):
if device == "cpu":
pytest.skip("CPU implementation is not available")
......@@ -64,7 +64,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_mm_dequant(self, device):
A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
row_stats = torch.randn(256, dtype=torch.float32, device=device)
......@@ -77,7 +77,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE)
def test_int8_scaled_mm(self, device, dtype, has_bias):
......@@ -96,7 +96,7 @@ class TestLLMInt8Ops:
class TestInt8BlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_blockwise(self, device, dtype, blocksize):
......@@ -116,7 +116,7 @@ class TestInt8BlockwiseQuantOps:
torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_blockwise(self, device, dtype, blocksize):
......@@ -140,7 +140,7 @@ class TestInt8BlockwiseQuantOps:
class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......@@ -164,7 +164,7 @@ class Test4bitBlockwiseQuantOps:
torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......@@ -197,7 +197,7 @@ class Test4bitBlockwiseQuantOps:
torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......
......@@ -47,7 +47,6 @@ str2optimizers["momentum_pytorch"] = (
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
......@@ -88,19 +87,14 @@ str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
......@@ -110,10 +104,6 @@ str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
......@@ -128,8 +118,7 @@ str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
......@@ -142,10 +131,8 @@ str2statenames["paged_adamw8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
......@@ -180,7 +167,7 @@ optimizer_names_32bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
pytest.skip()
if dim1 == 1 and dim2 == 1:
......@@ -256,7 +243,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_global_config(dim1, dim2, gtype):
def test_global_config(requires_cuda, dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
......@@ -298,10 +285,11 @@ def test_global_config(dim1, dim2, gtype):
optimizer_names_8bit = [
"adam8bit",
"lion8bit",
"momentum8bit",
"rmsprop8bit",
# Non-blockwise optimizers are deprecated.
# "adam8bit",
# "lion8bit",
# "momentum8bit",
# "rmsprop8bit",
"adam8bit_blockwise",
"lion8bit_blockwise",
"momentum8bit_blockwise",
......@@ -315,7 +303,7 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
torch.set_printoptions(precision=6)
if gtype == torch.bfloat16 and "blockwise" not in optim_name:
......@@ -479,7 +467,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
@pytest.mark.deprecated
def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment