Unverified Commit 1088ec52 authored by Matthew Douglas, committed by GitHub

Updates for device agnosticism (#1601)

* Include device support tags for transformers multi-backend compatibility; add xpu() and cpu() to Params4bit

* Make test suite more device-agnostic

* Additional device agnostic tests

* Additional device agnosticism for tests

* Add BNB_TEST_DEVICE env var to manually select device for unit tests


* Small bugfix for int8 test

* Exclude backward() from code coverage reports

* Params4bit: don't try to quantize when moving to meta device
parent 97073cdb
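The BNB_TEST_DEVICE variable mentioned in the commit message is read by the new tests/helpers.py shown further down. A minimal sketch of pinning the test suite to one backend (assumes the helper added below; any device string your build supports works):

import os
os.environ["BNB_TEST_DEVICE"] = "xpu"  # e.g. "cpu", "cuda", "mps", "hpu", "npu"

from tests.helpers import get_available_devices
assert get_available_devices() == ["xpu"]  # auto-detection is bypassed entirely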
@@ -20,15 +20,15 @@ from .nn import modules
 from .optim import adam

 # This is a signal for integrations with transformers/diffusers.
-# Eventually, we will remove this and check based on release version.
+# Eventually we may remove this but it is currently required for compatibility.
 features = {"multi-backend"}

 supported_torch_devices = {
-    "cuda",
     "cpu",
-    # "mps",
-    # "xpu",
-    # "hpu",
-    # "npu",
+    "cuda",  # NVIDIA/AMD GPU
+    "xpu",  # Intel GPU
+    "hpu",  # Gaudi
+    "npu",  # Ascend NPU
+    "mps",  # Apple Silicon
 }
 if torch.cuda.is_available():
...
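For context, a rough sketch of how a downstream integration (e.g. transformers) might consume these tags; the integration-side code is not part of this commit and the fallback branch is hypothetical:

import bitsandbytes as bnb

def bnb_supports(device_type: str) -> bool:
    # "multi-backend" signals that supported_torch_devices is meaningful.
    if "multi-backend" not in getattr(bnb, "features", set()):
        return device_type == "cuda"  # hypothetical fallback for older builds
    return device_type in bnb.supported_torch_devices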
@@ -284,7 +284,7 @@ class MatMul8bitLt(torch.autograd.Function):
                    dtype=torch.float16,
                )
-            if state.threshold > 0.0 and subA is not None:
+            if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
        if req_gradA:
...
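The extra subA.numel() > 0 check guards the sparse-decomposition gradient update against empty inputs (the test matrices below include a dim2 == 0 case). A small illustration of the edge case, not taken from the library:

import torch

subA = torch.empty(0, 32, dtype=torch.float16)          # zero-row outlier slice
grad_output = torch.empty(0, 64, dtype=torch.float16)   # matching empty batch
# torch.matmul(grad_output.t(), subA) is a (64, 32) tensor of zeros here, so the
# update is a no-op and is now skipped explicitly; some non-CUDA backends may not
# handle the empty matmul gracefully.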
@@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
    for i in range(gap):
        values.append(0)
    values.sort()
-    code = torch.Tensor(values)
+    code = torch.tensor(values)
    code /= code.max()
    return code
...
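A side note on the one-line change above: torch.Tensor(values) is the legacy constructor, while torch.tensor(values) is the recommended factory that infers dtype from the data; for a list of Python floats both yield float32, so behaviour is unchanged:

import torch

values = [0.0, 0.25, 0.5, 1.0]
assert torch.Tensor(values).dtype == torch.float32
assert torch.tensor(values).dtype == torch.float32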
@@ -306,9 +306,15 @@ class Params4bit(torch.nn.Parameter):
         self.bnb_quantized = True
         return self

+    def cpu(self):
+        return self.to(device="cpu")
+
     def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
         return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

+    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
+
     @overload
     def to(
         self: T,
@@ -326,7 +332,7 @@ class Params4bit(torch.nn.Parameter):
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if device is not None and device.type == "cuda" and not self.bnb_quantized:
+        if device is not None and device.type != "meta" and not self.bnb_quantized:
             return self._quantize(device)
         else:
             if self.quant_state is not None:
...
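Taken together, the Params4bit changes mean quantization now triggers on a move to any real device rather than only CUDA, while the meta device is left alone. A hedged usage sketch (API names as in the diff above; shapes and the choice of "xpu" are illustrative):

import torch
import bitsandbytes as bnb

w = torch.randn(128, 128, dtype=torch.float16)
p = bnb.nn.Params4bit(data=w, quant_type="nf4", requires_grad=False)

p_meta = p.to("meta")   # no quantization is attempted on the meta device
p_dev = p.to("xpu")     # quantizes on the first move to a real accelerator
p_cpu = p.cpu()         # new convenience wrapper around .to(device="cpu")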
@@ -79,6 +79,12 @@ include = ["bitsandbytes*"]
 [tool.setuptools.dynamic]
 version = {attr = "bitsandbytes.__version__"}

+[tool.coverage.report]
+exclude_also = [
+    # exclude backward() functions from coverage, as they are invoked from C++
+    'def backward\(ctx'
+]
+
 [tool.pytest.ini_options]
 addopts = "-rP -m 'not slow and not benchmark and not deprecated'"
 # ; --cov=bitsandbytes
...
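The new [tool.coverage.report] block excludes autograd backward() bodies from coverage because they are entered from C++ rather than from Python test code. A minimal illustration of what the regex matches (hypothetical function, not from the library):

import torch

class _ExampleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):  # matches 'def backward\(ctx' -> excluded from coverage
        return grad_output * 2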
+import functools
 from io import BytesIO
 from itertools import product
+import os
 import random
 from typing import Any
@@ -13,6 +15,38 @@ BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))  # all combinations of (bo
 BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)


+@functools.cache
+def get_available_devices():
+    if "BNB_TEST_DEVICE" in os.environ:
+        # If the environment variable is set, use it directly.
+        return [os.environ["BNB_TEST_DEVICE"]]
+
+    devices = ["cpu"]
+
+    if hasattr(torch, "accelerator"):
+        # PyTorch 2.6+ - determine accelerator using agnostic API.
+        if torch.accelerator.is_available():
+            devices += [str(torch.accelerator.current_accelerator())]
+    else:
+        if torch.cuda.is_available():
+            devices += ["cuda"]
+
+        if torch.backends.mps.is_available():
+            devices += ["mps"]
+
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            devices += ["xpu"]
+
+        custom_backend_name = torch._C._get_privateuse1_backend_name()
+        custom_backend_module = getattr(torch, custom_backend_name, None)
+        custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)
+        if custom_backend_is_available_fn and custom_backend_module.is_available():
+            devices += [custom_backend_name]
+
+    return devices
+
+
 def torch_save_to_buffer(obj):
     buffer = BytesIO()
     torch.save(obj, buffer)
...
@@ -6,12 +6,14 @@ from tests.helpers import (
     BOOLEAN_TRIPLES,
     TRUE_FALSE,
     describe_dtype,
+    get_available_devices,
     id_formatter,
 )

 TRANSPOSE_VALS = [(False, True), (False, False)]


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@@ -27,10 +29,16 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
+def test_matmullt(
+    device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
+):
+    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
+        # TODO: Deprecate/remove?
+        pytest.skip("switchback_bnb only works on CUDA.")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
+    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
     if has_bias == False:
         req_grad = list(req_grad)
         req_grad[2] = False
@@ -38,21 +46,21 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(
                 size=(dim2, dim4),
-                device="cuda",
+                device=device,
                 requires_grad=req_grad[1],
                 dtype=dtype,
             )
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
@@ -91,7 +99,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
             if has_fp16_weights:
                 if any(req_grad):
                     out_bnb.data.copy_(out_torch)
-                    torch.cuda.synchronize()
+                    if device == "cuda":
+                        torch.cuda.synchronize()
                     loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                     loss_bnb.backward()
                     gradA1 = A.grad
@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                     torch.testing.assert_close(gradBias1, gradBias2)


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
 def test_matmul_4bit(
+    device,
     dim1,
     dim2,
     dim3,
@@ -159,6 +170,9 @@ def test_matmul_4bit(
     compress_statistics,
     quant_type,
 ):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("Only nf4 is supported on CPU")
+
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
@@ -168,13 +182,13 @@ def test_matmul_4bit(
     for i in range(3):
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
             bias = None
             bias2 = None
             if has_bias:
-                bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
+                bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
                 bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
@@ -204,7 +218,8 @@ def test_matmul_4bit(
             # assert err < 0.20
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
+                if device == "cuda":
+                    torch.cuda.synchronize()
                 loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                 loss_bnb.backward()
                 gradA1 = A.grad
...
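The tests above now guard torch.cuda.synchronize() behind a device check. A hypothetical helper (not part of this PR) that would express the same intent device-agnostically:

import torch

def synchronize(device: str) -> None:
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.synchronize()
    elif device == "mps":
        torch.mps.synchronize()
    # "cpu" and other backends: nothing to wait on here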
@@ -13,6 +13,7 @@ from tests.helpers import (
     BOOLEAN_TUPLES,
     TRUE_FALSE,
     describe_dtype,
+    get_available_devices,
     get_test_dims,
     id_formatter,
 )
...@@ -87,15 +88,26 @@ class Timer: ...@@ -87,15 +88,26 @@ class Timer:
class Test8BitBlockwiseQuantizeFunctional: class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed): def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
if device == "cpu":
# This test is slow on CPU, so avoid atypical use cases.
if nested:
pytest.skip("Not a typical use case.")
if blocksize != 256:
pytest.skip("Only blocksize 256 is the typical one supported on CPU.")
if dtype != torch.float32:
pytest.xfail(f"CPU implementation currently only supports float32, got {dtype}")
diffs = [] diffs = []
reldiffs = [] reldiffs = []
for i in range(100): for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
A2 = F.dequantize_blockwise(C, S) A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float() diff = torch.abs(A1 - A2).float()
...@@ -113,7 +125,7 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -113,7 +125,7 @@ class Test8BitBlockwiseQuantizeFunctional:
diffs = [] diffs = []
code = F.create_dynamic_map(signed=signed) code = F.create_dynamic_map(signed=signed)
for i in range(100): for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype) A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
A2 = F.dequantize_blockwise(C, S) A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float() diff = torch.abs(A1 - A2).float()
...@@ -154,21 +166,27 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -154,21 +166,27 @@ class Test8BitBlockwiseQuantizeFunctional:
# print(sum(diffs)/len(diffs)) # print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs)) # print(sum(reldiffs)/len(reldiffs))
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
def test_few_bit_quant(self, bits, method): def test_few_bit_quant(self, device, bits, method):
if device == "cpu" and bits != 8:
pytest.skip("CPU implementation only supports 8 bits")
abserrs = [] abserrs = []
relerrs = [] relerrs = []
code = None code = None
if method == "linear": if method == "linear":
code = F.create_linear_map(True, total_bits=bits).cuda() code = F.create_linear_map(True, total_bits=bits).to(device)
elif method == "fp8": elif method == "fp8":
ebits = math.ceil(bits / 2) ebits = math.ceil(bits / 2)
pbits = bits - ebits - 1 pbits = bits - ebits - 1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda() code = F.create_fp8_map(True, ebits, pbits, bits).to(device)
elif method == "dynamic": elif method == "dynamic":
code = F.create_dynamic_map(True, bits - 0, bits).cuda() code = F.create_dynamic_map(True, bits - 0, bits).to(device)
elif method == "quantile": elif method == "quantile":
if device != "cuda":
pytest.xfail("Quantile map only works on CUDA")
values = torch.randn(2048, 2048, device="cuda") values = torch.randn(2048, 2048, device="cuda")
code = F.create_quantile_map(values, bits).cuda() code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero # for some data types we have no zero
...@@ -178,7 +196,7 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -178,7 +196,7 @@ class Test8BitBlockwiseQuantizeFunctional:
# print(method, (code==0).sum()) # print(method, (code==0).sum())
assert code.numel() == 256 assert code.numel() == 256
for i in range(10): for i in range(10):
values = torch.randn(1, 32, device="cuda") values = torch.randn(1, 32, device=device)
values /= values.abs().max() values /= values.abs().max()
# values[values.abs() < 1e-6] += 1e-5 # values[values.abs() < 1e-6] += 1e-5
...@@ -189,8 +207,8 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -189,8 +207,8 @@ class Test8BitBlockwiseQuantizeFunctional:
q1.append(idx.item()) q1.append(idx.item())
v1.append(code[idx].item()) v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda() q1 = torch.tensor(q1, device=device)
v1 = torch.Tensor(v1).cuda() v1 = torch.tensor(v1, device=device)
q2, S2 = F.quantize_blockwise(values, code=code) q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2) v2 = F.dequantize_blockwise(q2, S2)
...@@ -206,15 +224,20 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -206,15 +224,20 @@ class Test8BitBlockwiseQuantizeFunctional:
else: else:
torch.testing.assert_close(q1, q2) torch.testing.assert_close(q1, q2)
def test_fp8_quant(self): @pytest.mark.parametrize("device", get_available_devices())
def test_fp8_quant(self, device):
# TODO
if device == "cpu":
pytest.skip("CPU implementation segfaults")
for e_bits in range(1, 7): for e_bits in range(1, 7):
p_bits = 7 - e_bits p_bits = 7 - e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda() code = F.create_fp8_map(True, e_bits, p_bits).to(device)
abserr = [] abserr = []
relerr = [] relerr = []
for i in range(100): for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda") A1 = torch.randn(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1, code=code) C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC) A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2) diff = torch.abs(A1 - A2)
...@@ -228,7 +251,7 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -228,7 +251,7 @@ class Test8BitBlockwiseQuantizeFunctional:
abserr = [] abserr = []
relerr = [] relerr = []
for i in range(100): for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda") A1 = torch.rand(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1, code=code) C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC) A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2) diff = torch.abs(A1 - A2)
...@@ -242,7 +265,7 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -242,7 +265,7 @@ class Test8BitBlockwiseQuantizeFunctional:
abserr = [] abserr = []
relerr = [] relerr = []
for i in range(100): for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda") A1 = torch.randn(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1) C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC) A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2) diff = torch.abs(A1 - A2)
...@@ -329,6 +352,7 @@ methods = { ...@@ -329,6 +352,7 @@ methods = {
} }
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestIGEMMFunctional: class TestIGEMMFunctional:
@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
...@@ -532,36 +556,38 @@ class TestIGEMMFunctional: ...@@ -532,36 +556,38 @@ class TestIGEMMFunctional:
class TestLLMInt8Functional: class TestLLMInt8Functional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_int8_linear_matmul(self, dim1, dim2, dim3, dim4, dims, ldb): def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) A = torch.randint(-128, 127, size=(dim1, dim3), dtype=torch.int8, device=device)
elif dims == 3: elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) A = torch.randint(-128, 127, size=(dim1, dim2, dim3), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=(dim4, dim3), dtype=torch.int8, device=device)
C1 = torch.matmul(A.float(), B.t().float()) C1 = torch.matmul(A.float(), B.t().float())
C2 = F.int8_linear_matmul(A, B) C2 = F.int8_linear_matmul(A, B)
torch.testing.assert_close(C1, C2.float()) torch.testing.assert_close(C1, C2.float())
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims): def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims):
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() A = torch.normal(0, 0.5, size=(dim1, dim3), device=device).half()
elif dims == 3: elif dims == 3:
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device=device).half()
B = torch.randn((dim4, dim3), device="cuda").half() B = torch.randn((dim4, dim3), device=device).half()
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t()) C1 = torch.matmul(A, B.t())
...@@ -573,19 +599,20 @@ class TestLLMInt8Functional: ...@@ -573,19 +599,20 @@ class TestLLMInt8Functional:
torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(self, dim1, dim4, dims, has_bias): def test_dequant_mm(self, device, dim1, dim4, dims, has_bias):
inner = 128 inner = 128
bias = None bias = None
if has_bias: if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=torch.float16) bias = torch.randn(dim4, device=device, dtype=torch.float16)
for i in range(1): for i in range(1):
A = torch.randn(dim1, inner, device="cuda") A = torch.randn(dim1, inner, device=device)
B = torch.randn(dim4, inner, device="cuda") B = torch.randn(dim4, inner, device=device)
C1 = torch.matmul(A.half(), B.t().half()) C1 = torch.matmul(A.half(), B.t().half())
if has_bias: if has_bias:
C1 += bias C1 += bias
...@@ -618,6 +645,7 @@ class TestLLMInt8Functional: ...@@ -618,6 +645,7 @@ class TestLLMInt8Functional:
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp")) @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
@pytest.mark.deprecated
def test_colrow_absmax(self, dim1, dim2, dims, threshold): def test_colrow_absmax(self, dim1, dim2, dims, threshold):
for i in range(k): for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half() A = torch.randn(dim1, dim2, device="cuda").half()
...@@ -654,6 +682,7 @@ class TestLLMInt8Functional: ...@@ -654,6 +682,7 @@ class TestLLMInt8Functional:
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
@pytest.mark.deprecated
def test_int8_double_quant(self, dim1, dim2): def test_int8_double_quant(self, dim1, dim2):
for i in range(k): for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half() A = torch.randn(dim1, dim2, device="cuda").half()
...@@ -686,6 +715,7 @@ class TestLLMInt8Functional: ...@@ -686,6 +715,7 @@ class TestLLMInt8Functional:
torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Srow.flatten().float(), statsA)
torch.testing.assert_close(Scol.flatten().float(), statsAt) torch.testing.assert_close(Scol.flatten().float(), statsAt)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize( @pytest.mark.parametrize(
("dim1", "dim4", "inner"), ("dim1", "dim4", "inner"),
( (
...@@ -697,10 +727,10 @@ class TestLLMInt8Functional: ...@@ -697,10 +727,10 @@ class TestLLMInt8Functional:
) )
), ),
) )
def test_integrated_int8_linear_matmul(self, dim1, dim4, inner): def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner):
for i in range(k): for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half() A = torch.randn(dim1, inner, device=device).half()
B = torch.randn(dim4, inner, device="cuda").half() B = torch.randn(dim4, inner, device=device).half()
out1 = torch.matmul(A.half(), B.t().half()) out1 = torch.matmul(A.half(), B.t().half())
...@@ -724,12 +754,13 @@ class TestLLMInt8Functional: ...@@ -724,12 +754,13 @@ class TestLLMInt8Functional:
err2 = torch.abs(out1 - out3).mean().item() err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.025 assert err2 <= err1 * 1.025
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(self, dim1, dim2): def test_coo_double_quant(self, device, dim1, dim2):
threshold = 2.00 threshold = 2.00
for i in range(k): for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half() A = torch.randn(dim1, dim2, device=device).half()
idx = torch.abs(A) >= threshold idx = torch.abs(A) >= threshold
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
...@@ -743,12 +774,13 @@ class TestLLMInt8Functional: ...@@ -743,12 +774,13 @@ class TestLLMInt8Functional:
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_int8_vectorwise_quant(self, dim1, dim2): def test_coo_int8_vectorwise_quant(self, device, dim1, dim2):
threshold = 3.00 threshold = 3.00
for i in range(k): for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half() A = torch.randn(dim1, dim2, device=device).half()
idx = torch.abs(A) >= threshold idx = torch.abs(A) >= threshold
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
...@@ -759,6 +791,7 @@ class TestLLMInt8Functional: ...@@ -759,6 +791,7 @@ class TestLLMInt8Functional:
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestSpMMFunctional: class TestSpMMFunctional:
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2"))
...@@ -1025,6 +1058,7 @@ class TestSpMMFunctional: ...@@ -1025,6 +1058,7 @@ class TestSpMMFunctional:
print("partial matmul", time.time() - t0) print("partial matmul", time.time() - t0)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestSparseTensorFunctional: class TestSparseTensorFunctional:
def test_coo2csr(self): def test_coo2csr(self):
threshold = 1 threshold = 1
...@@ -1063,11 +1097,12 @@ class TestSparseTensorFunctional: ...@@ -1063,11 +1097,12 @@ class TestSparseTensorFunctional:
class TestQuantize4BitFunctional: class TestQuantize4BitFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(self, dtype, quant_type, blocksize): def test_4bit_quant(self, device, dtype, quant_type, blocksize):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
...@@ -1095,13 +1130,14 @@ class TestQuantize4BitFunctional: ...@@ -1095,13 +1130,14 @@ class TestQuantize4BitFunctional:
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
assert err.item() < math.log2(blocksize) * 8e-2 assert err.item() < math.log2(blocksize) * 8e-2
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
def test_4bit_compressed_stats(self, quant_type, blocksize): def test_4bit_compressed_stats(self, device, quant_type, blocksize):
errs1 = [] errs1 = []
errs2 = [] errs2 = []
for i in range(10): for i in range(10):
A1 = torch.randn(1024, 1024, device="cuda").half() A1 = torch.randn(1024, 1024, device=device).half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
...@@ -1127,6 +1163,7 @@ class TestQuantize4BitFunctional: ...@@ -1127,6 +1163,7 @@ class TestQuantize4BitFunctional:
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"]) @pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_4bit_dequant(self, quant_type): def test_bench_4bit_dequant(self, quant_type):
blocksize = 256 blocksize = 256
...@@ -1157,6 +1194,7 @@ class TestQuantize4BitFunctional: ...@@ -1157,6 +1194,7 @@ class TestQuantize4BitFunctional:
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6) # print((time.time()-t0)/iters*1e6)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
...@@ -1167,7 +1205,7 @@ class TestQuantize4BitFunctional: ...@@ -1167,7 +1205,7 @@ class TestQuantize4BitFunctional:
ids=describe_dtype, ids=describe_dtype,
) )
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, kind): def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
errs1 = [] errs1 = []
errs2 = [] errs2 = []
errs3 = [] errs3 = []
...@@ -1180,17 +1218,17 @@ class TestQuantize4BitFunctional: ...@@ -1180,17 +1218,17 @@ class TestQuantize4BitFunctional:
for i in range(100): for i in range(100):
if kind == "fc1": if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device="cuda") A = torch.randn(1, dim, dtype=dtype, device=device)
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim)
elif kind == "fc2": elif kind == "fc2":
A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") A = torch.randn(1, 4 * dim, dtype=dtype, device=device)
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) B = torch.randn(dim, 4 * dim, dtype=dtype, device=device) / math.sqrt(dim)
elif kind == "attn": elif kind == "attn":
A = torch.randn(1, dim, dtype=dtype, device="cuda") A = torch.randn(1, dim, dtype=dtype, device=device)
B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) B = torch.randn(dim, dim, dtype=dtype, device=device) / math.sqrt(dim)
elif kind == "attn_packed": elif kind == "attn_packed":
A = torch.randn(1, dim, dtype=dtype, device="cuda") A = torch.randn(1, dim, dtype=dtype, device=device)
B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) B = torch.randn(dim * 3, dim, dtype=dtype, device=device) / math.sqrt(dim)
qB, state = F.quantize_4bit( qB, state = F.quantize_4bit(
B, B,
...@@ -1294,18 +1332,19 @@ class TestQuantize4BitFunctional: ...@@ -1294,18 +1332,19 @@ class TestQuantize4BitFunctional:
assert relratio < 1.04 and relratio > 0.96 assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98 assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(self, storage_type, dtype, double_quant): def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
dims = 10 dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242)) torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims) dims = get_test_dims(0, 8192, n=dims)
dims = [dim + (64 - (dim % 64)) for dim in dims] dims = [dim + (64 - (dim % 64)) for dim in dims]
# for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
for dim in dims: for dim in dims:
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
B = torch.eye(dim, dtype=dtype, device="cuda") B = torch.eye(dim, dtype=dtype, device=device)
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t()) C3 = torch.matmul(A, B.t())
......
@@ -7,7 +7,7 @@ import pytest
 import torch

 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer

 storage = {
     "uint8": torch.uint8,
@@ -17,15 +17,18 @@ storage = {
 }


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
-@pytest.mark.parametrize("bias", TRUE_FALSE)
-@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
-@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
-def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
+    if device == "cpu":
+        pytest.xfail("Dequantization is not yet implemented for CPU")
+
     original_dtype = torch.float16
     compute_dtype = None
-    device = "cuda"
     layer_shape = (300, 400)

     linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
@@ -52,7 +55,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
-    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)

     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
@@ -174,18 +177,50 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert size_ratio < target_compression, ratio_error_msg


-def test_copy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_copy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)

     shallow_copy_param = copy.copy(param)
     assert param.quant_state is shallow_copy_param.quant_state
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


-def test_deepcopy_param():
-    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
-    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    tensor = torch.linspace(1, blocksize, blocksize)
+    param = bnb.nn.Params4bit(
+        data=tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        requires_grad=False,
+    ).to(device)
+
     dict_keys_before = set(param.__dict__.keys())
     copy_param = copy.deepcopy(param)
     dict_keys_after = set(param.__dict__.keys())
@@ -199,12 +234,27 @@ def test_deepcopy_param():
     assert dict_keys_before == dict_keys_copy


-def test_params4bit_real_serialization():
-    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
-    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("blocksize", [64, 128])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
+    if device == "cpu":
+        if compress_statistics:
+            pytest.skip("Currently segfaults on CPU")
+        if quant_type == "fp4":
+            pytest.xfail("FP4 not supported on CPU")
+
+    original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(
+        data=original_tensor,
+        quant_type=quant_type,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+    )

     dict_keys_before = set(original_param.__dict__.keys())
-    original_param.cuda(0)  # move to CUDA to trigger quantization
+    original_param.to(device)  # change device to trigger quantization
     serialized_param = pickle.dumps(original_param)
     deserialized_param = pickle.loads(serialized_param)
...
@@ -11,6 +11,7 @@ import bitsandbytes as bnb
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import (
     TRUE_FALSE,
+    get_available_devices,
     id_formatter,
     torch_load_from_buffer,
     torch_save_to_buffer,
@@ -19,7 +20,11 @@ from tests.helpers import (

 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
-def test_linear_no_igemmlt():
+@pytest.mark.parametrize("device", get_available_devices())
+def test_linear_no_igemmlt(device):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     linear = torch.nn.Linear(1024, 3072)
     x = torch.randn(3, 1024, dtype=torch.half)
     linear_custom = Linear8bitLt(
@@ -29,6 +34,8 @@ def test_linear_no_igemmlt():
         has_fp16_weights=False,
         threshold=6.0,
     )
+
+    # TODO: Remove, this is no longer implemented
     linear_custom.state.force_no_igemmlt = True

     linear_custom.weight = bnb.nn.Int8Params(
@@ -37,11 +44,11 @@ def test_linear_no_igemmlt():
         has_fp16_weights=False,
     ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
-    linear_custom = linear_custom.cuda()
-    linear = linear.half().cuda()
+    linear_custom = linear_custom.to(device)
+    linear = linear.half().to(device)

-    x_ref = x.clone().cuda().requires_grad_(True)
-    x_ours = x.clone().cuda().requires_grad_(True)
+    x_ref = x.clone().to(device).requires_grad_(True)
+    x_ours = x.clone().to(device).requires_grad_(True)
     fx_ref = linear(x_ref).float()
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()
@@ -58,18 +65,25 @@ def test_linear_no_igemmlt():
     torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)


+@pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
+@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
 @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
 @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
 def test_linear_serialization(
+    device,
     has_fp16_weights,
+    threshold,
     serialize_before_forward,
     deserialize_before_cuda,
     save_before_forward,
     load_before_cuda,
 ):
+    if device == "cpu":
+        pytest.xfail("Not yet implemented on CPU")
+
     linear = torch.nn.Linear(32, 96)
     # TODO: Fallback for bad shapes
     x = torch.randn(4, 32, dtype=torch.half)
@@ -80,7 +94,7 @@ def test_linear_serialization(
         linear.out_features,
         linear.bias is not None,
         has_fp16_weights=has_fp16_weights,
-        threshold=6.0,
+        threshold=threshold,
     )

     linear_custom.weight = bnb.nn.Int8Params(
@@ -89,7 +103,7 @@ def test_linear_serialization(
         has_fp16_weights=has_fp16_weights,
     )
     linear_custom.bias = linear.bias
-    linear_custom = linear_custom.cuda()
+    linear_custom = linear_custom.to(device)

     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
@@ -125,7 +139,7 @@ def test_linear_serialization(
         linear.out_features,
         linear.bias is not None,
         has_fp16_weights=has_fp16_weights,
-        threshold=6.0,
+        threshold=threshold,
     )

     if deserialize_before_cuda:
@@ -135,7 +149,7 @@ def test_linear_serialization(
     if load_before_cuda:
         new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

-    new_linear_custom = new_linear_custom.cuda()
+    new_linear_custom = new_linear_custom.to(device)
     if not deserialize_before_cuda:
         new_linear_custom.load_state_dict(new_state_dict, strict=True)
...
 import inspect
-import math

-import einops
 import pytest
 import torch
 from torch import nn

 import bitsandbytes as bnb
-from tests.helpers import id_formatter
+from tests.helpers import get_available_devices, id_formatter


 class MockArgs:
@@ -54,266 +52,32 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
         torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


+@pytest.mark.parametrize("device", get_available_devices())
-class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
norm = math.sqrt(math.pi) / math.sqrt(2.0)
# std = torch.abs(x).mean()*norm
std = torch.std(x)
max1 = std * trim_value
x = x / max1 * 127
x = round_func(x)
x[x > 127] = 127
x[x < -127] = -127
x = x / 127 * max1
return x
def quant(x, quant_type, dim=1):
if quant_type == "linear":
max1 = torch.abs(x).max().float()
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type == "vector":
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
elif quant_type == "min-max":
maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float()
scale = (maxA - minA) / 2.0
xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
return xq, (minA.float(), scale.float())
else:
return None
def dequant(xq, S1, S2, dtype, quant_type):
if quant_type == "linear":
norm = S1 * S2 / (127 * 127)
# double cast needed to prevent overflows
return (xq.float() * norm).to(dtype)
elif quant_type == "vector":
x = xq.float()
if len(xq.shape) == 2 and len(S1.shape) == 3:
S1 = S1.squeeze(0)
if len(xq.shape) == 2 and len(S2.shape) == 3:
S2 = S2.squeeze(0)
# print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2:
x *= S1.t() / 127
else:
x *= S1 / 127
x *= S2 / 127
return x.to(dtype)
else:
return None
def dequant_min_max(xq, A, B, SA, SB, dtype):
offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3:
SB = SB.squeeze(0)
if len(xq.shape) == 2 and len(SA.shape) == 3:
SA = SA.squeeze(0)
if len(SB.shape) == 2:
x *= SB.t() / 127
else:
x *= SB / 127
x *= SA[1] / 127
x += offset
return x.to(dtype)
def get_8bit_linear(x, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.abs(x).max()
x = x / max1 * 127
x = round_func(x) / 127 * max1
# x = torch.round(x)/128*max1
return x
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = (x * 127) / max1
x = round_func(x) / 127 * max1
return x
@staticmethod
def round_stoachastic(x):
sign = torch.sign(x)
absx = torch.abs(x)
decimal = absx - torch.floor(absx)
rdm = torch.rand_like(decimal)
return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
@staticmethod
def fake_8bit_storage(w, exponent_bits):
code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
# C = bnb.functional.quantize_no_absmax(code, w)
# out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
# print(out)
# out = out.half()
code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_stoachstic(w):
rand = torch.rand(1024, device=w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
out = bnb.functional.dequantize_blockwise(absmax, C)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_with_max(w, topk=8):
blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk]
max_val = max_val[:, :topk]
mask = torch.zeros_like(blocked_w)
mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
mask = mask.bool()
# 1. zero out max values
# 2. quantize + dequantize
# 3. write back max values
# 4. copy matrix back to weight
values = blocked_w[mask]
blocked_w[mask] = 0
code = bnb.functional.create_dynamic_map()
code = code.to(w.device)
absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
blocked_w[mask] = values
unblocked_w = blocked_w.flatten().view(w.shape)
w.copy_(unblocked_w)
return unblocked_w
@staticmethod
def forward(ctx, x, weight, bias=None, args=None):
if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
# if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float()
# relerr = err/(torch.abs(output32).float()+1e-8)
# print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else:
# output = torch.matmul(x, weight.t())
output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias)
ctx.args = args
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
x, weight, bias = ctx.saved_tensors
args = ctx.args
stochastic = False
grad_input = grad_weight = grad_bias = None
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
# weight and x are already 8bit
# -> transform grad_output to 8-bit
if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
# grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
else:
grad_input = grad_output.matmul(weight)
grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super().__init__()
self.input_features = input_features
self.output_features = output_features
self.args = args
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
torch.nn.init.zeros_(self.bias)
def forward(self, x):
self.args.training = self.training
return LinearFunction.apply(x, self.weight, self.bias, self.args)
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold")) @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
def test_linear8bitlt_inference(threshold): def test_linear8bitlt_inference(device, threshold):
l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() if device == "cpu":
assert l1.weight.device.type == "cuda" pytest.xfail("Not yet implemented on CPU")
assert l1.weight.dtype == torch.float16
l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
assert l1.weight.device.type == device
assert l1.weight.dtype == torch.int8
l1.eval() l1.eval()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device).half()
o1 = l1(b1) o1 = l1(b1)
if i == 1: if i == 1:
assert l1.state.CB is not None assert l1.state.CB is not None
def test_linear8bitlt_accumulated_gradient(): # TODO: Remove support for training int8 weights
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) @pytest.mark.parametrize("device", get_available_devices())
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) def test_linear8bitlt_accumulated_gradient(device):
if device != "cuda":
pytest.skip("Only supported on CUDA")
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)])
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).to(device).half() for i in range(2)])
l1[0].weight.data.copy_(l2[0].weight.data) l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data) l1[1].weight.data.copy_(l2[1].weight.data)
l1[0].bias.data.copy_(l2[0].bias.data) l1[0].bias.data.copy_(l2[0].bias.data)
...@@ -325,7 +89,7 @@ def test_linear8bitlt_accumulated_gradient(): ...@@ -325,7 +89,7 @@ def test_linear8bitlt_accumulated_gradient():
acc_steps = 10 acc_steps = 10
for i in range(15): for i in range(15):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device).half()
o1 = l1(b1) o1 = l1(b1)
o2 = l2(b1) o2 = l2(b1)
loss1 = o1.mean() loss1 = o1.mean()
...@@ -353,8 +117,12 @@ def test_linear8bitlt_accumulated_gradient(): ...@@ -353,8 +117,12 @@ def test_linear8bitlt_accumulated_gradient():
assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1) assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("threshold", [0.0, 2.0])
def test_linear8bitlt_no_fp16_weights(threshold): def test_linear8bitlt_no_fp16_weights(device, threshold):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
l1 = ( l1 = (
bnb.nn.Linear8bitLt( bnb.nn.Linear8bitLt(
32, 32,
...@@ -362,23 +130,23 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -362,23 +130,23 @@ def test_linear8bitlt_no_fp16_weights(threshold):
threshold=threshold, threshold=threshold,
has_fp16_weights=False, has_fp16_weights=False,
) )
.cuda() .to(device)
.half() .half()
) )
assert l1.weight.dtype == torch.int8 assert l1.weight.dtype == torch.int8
l1.eval() l1.eval()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = l1(b1) o1 = l1(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda() mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device)
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: if threshold > 0:
...@@ -386,12 +154,12 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -386,12 +154,12 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0: if threshold > 0:
assert mlp.fc2.state.idx is not None assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: if threshold > 0:
...@@ -399,10 +167,10 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -399,10 +167,10 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0: if threshold > 0:
assert mlp.fc2.state.idx is not None assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: if threshold > 0:
...@@ -420,11 +188,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -420,11 +188,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
has_fp16_weights=False, has_fp16_weights=False,
) )
.half() .half()
.to("cuda") .to(device)
) )
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: if threshold > 0:
...@@ -433,8 +201,8 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -433,8 +201,8 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.state.idx is not None assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda" assert mlp.fc1.weight.device.type == device
assert mlp.fc2.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == device
mlp = MLP8bit( mlp = MLP8bit(
32, 32,
...@@ -442,11 +210,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -442,11 +210,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
threshold=threshold, threshold=threshold,
has_fp16_weights=False, has_fp16_weights=False,
) )
w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device) # grab weights before quantization,
mlp = mlp.cuda().half() # and this line triggers quantization mlp = mlp.cuda().half() # and this line triggers quantization
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: if threshold > 0:
...@@ -456,10 +224,10 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -456,10 +224,10 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda" assert mlp.fc1.weight.device.type == device
assert mlp.fc2.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == device
b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half) b1 = torch.randn(16, 8, 32, device=device, requires_grad=True, dtype=torch.half)
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
assert o1.requires_grad assert o1.requires_grad
...@@ -475,33 +243,37 @@ def test_linear8bitlt_no_fp16_weights(threshold): ...@@ -475,33 +243,37 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert (idx == 0).sum().item() <= b1.numel() * 0.005 assert (idx == 0).sum().item() <= b1.numel() * 0.005
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"module", "module",
[ [
lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False), lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
bnb.nn.LinearFP4, bnb.nn.LinearNF4,
], ],
ids=["Int8Lt", "FP4"], ids=["Int8Lt", "NF4"],
) )
def test_linear_kbit_fp32_bias(module): def test_linear_kbit_fp32_bias(device, module):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
# casts model to fp16 -> int8 automatically # casts model to fp16 -> int8 automatically
l1 = module(32, 64).cuda() l1 = module(32, 64).to(device)
assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.weight.dtype in [torch.int8, torch.uint8]
assert l1.bias.dtype == torch.float32 assert l1.bias.dtype == torch.float32
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
# casts bias to fp32 # casts bias to fp32
o1 = l1(b1) o1 = l1(b1)
assert l1.bias.dtype == torch.float16 assert l1.bias.dtype == torch.float16
# casts model to fp16 -> int8 automatically # casts model to fp16 -> int8 automatically
l1 = module(32, 64, bias=False).cuda() l1 = module(32, 64, bias=False).to(device)
assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.weight.dtype in [torch.int8, torch.uint8]
assert l1.bias is None assert l1.bias is None
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = l1(b1) o1 = l1(b1)
assert l1.bias is None assert l1.bias is None
...@@ -519,8 +291,12 @@ module_dict = { ...@@ -519,8 +291,12 @@ module_dict = {
} }
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module): def test_kbit_backprop(device, module):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
b = 16 b = 16
dim1 = 36 dim1 = 36
dim2 = 84 dim2 = 84
...@@ -536,16 +312,16 @@ def test_kbit_backprop(module): ...@@ -536,16 +312,16 @@ def test_kbit_backprop(module):
kbit[1].weight.detach().copy_(ref[1].weight) kbit[1].weight.detach().copy_(ref[1].weight)
kbit[0].bias.detach().copy_(ref[0].bias) kbit[0].bias.detach().copy_(ref[0].bias)
kbit[1].bias.detach().copy_(ref[1].bias) kbit[1].bias.detach().copy_(ref[1].bias)
ref = ref.half().cuda() ref = ref.half().to(device)
kbit = kbit.half().cuda() kbit = kbit.half().to(device)
kbit = kbit.half().to("cuda") kbit = kbit.half().to(device)
errs1 = [] errs1 = []
errs2 = [] errs2 = []
relerrs1 = [] relerrs1 = []
relerrs2 = [] relerrs2 = []
for i in range(100): for i in range(100):
batch = torch.randn(b, dim1).half().cuda() batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
out1 = ref(batch) out1 = ref(batch)
out2 = kbit(batch) out2 = kbit(batch)
out1.mean().backward() out1.mean().backward()
...@@ -578,6 +354,7 @@ def test_kbit_backprop(module): ...@@ -578,6 +354,7 @@ def test_kbit_backprop(module):
assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
@pytest.mark.deprecated
def test_fp8linear(): def test_fp8linear():
b = 10 b = 10
h = 1024 h = 1024
...@@ -608,6 +385,7 @@ def test_fp8linear(): ...@@ -608,6 +385,7 @@ def test_fp8linear():
assert bgraderr < 0.00002 assert bgraderr < 0.00002
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("embedding_dim", [64, 65]) @pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str) @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -621,7 +399,10 @@ def test_fp8linear(): ...@@ -621,7 +399,10 @@ def test_fp8linear():
], ],
ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
) )
def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage): def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
num_embeddings = 128 num_embeddings = 128
src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to( src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
...@@ -641,10 +422,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s ...@@ -641,10 +422,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
e.load_state_dict(emb_base.state_dict()) e.load_state_dict(emb_base.state_dict())
emb_base.cuda() emb_base.to(device)
e.cuda() e.to(device)
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda") input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
torch.testing.assert_close( torch.testing.assert_close(
actual=e(input_tokens), actual=e(input_tokens),
...@@ -652,6 +433,7 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s ...@@ -652,6 +433,7 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
) )
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("embedding_dim", [64, 65]) @pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str) @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -665,7 +447,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s ...@@ -665,7 +447,10 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
], ],
ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
) )
def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage): def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
pytest.xfail("Not yet supported on CPU")
is_8bit = embedding_class is bnb.nn.Embedding8bit is_8bit = embedding_class is bnb.nn.Embedding8bit
num_embeddings = 128 num_embeddings = 128
...@@ -685,10 +470,10 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor ...@@ -685,10 +470,10 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
e.load_state_dict(emb_base.state_dict()) e.load_state_dict(emb_base.state_dict())
emb_base.cuda() emb_base.to(device)
e.cuda() e.to(device)
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda") input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device=device)
torch.testing.assert_close( torch.testing.assert_close(
actual=e(input_tokens), actual=e(input_tokens),
...@@ -698,46 +483,64 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor ...@@ -698,46 +483,64 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
) )
def test_4bit_linear_warnings(): @pytest.mark.parametrize("device", get_available_devices())
def test_4bit_linear_warnings(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
dim1 = 64 dim1 = 64
with pytest.warns(UserWarning, match=r"inference or training"): with pytest.warns(UserWarning, match=r"inference or training"):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = nn.Sequential(
net = net.cuda() *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
inp = torch.rand(10, dim1).cuda().half() )
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp) net(inp)
with pytest.warns(UserWarning, match=r"inference."): with pytest.warns(UserWarning, match=r"inference."):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = nn.Sequential(
net = net.cuda() *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
inp = torch.rand(1, dim1).cuda().half() )
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp) net(inp)
with pytest.warns(UserWarning) as record: with pytest.warns(UserWarning) as record:
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = nn.Sequential(
net = net.cuda() *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
inp = torch.rand(10, dim1).cuda().half() )
net = net.to(device)
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
net(inp) net(inp)
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = nn.Sequential(
net = net.cuda() *[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
inp = torch.rand(1, dim1).cuda().half() )
net = net.to(device)
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
net(inp) net(inp)
assert len(record) == 2 assert len(record) == 2
def test_4bit_embedding_warnings(): @pytest.mark.parametrize("device", get_available_devices())
def test_4bit_embedding_warnings(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
num_embeddings = 128 num_embeddings = 128
default_block_size = 64 default_block_size = 64
with pytest.warns(UserWarning, match=r"inference."): with pytest.warns(UserWarning, match=r"inference."):
net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1) net = bnb.nn.Embedding4bit(
net.cuda() num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda") )
net.to(device)
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
net(inp) net(inp)
def test_4bit_embedding_weight_fsdp_fix(): def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
num_embeddings = 64 num_embeddings = 64
embedding_dim = 32 embedding_dim = 32
...@@ -754,7 +557,7 @@ def test_4bit_embedding_weight_fsdp_fix(): ...@@ -754,7 +557,7 @@ def test_4bit_embedding_weight_fsdp_fix():
assert module.weight.quant_state is not None assert module.weight.quant_state is not None
def test_4bit_linear_weight_fsdp_fix(): def test_4bit_linear_weight_fsdp_fix(requires_cuda):
inp_size = 64 inp_size = 64
out_size = 32 out_size = 32
......
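Throughout this file the hard-coded .cuda() / device="cuda" calls are replaced by a device parameter drawn from get_available_devices() in tests.helpers. That helper is imported but not defined in this diff; a plausible sketch, assuming it simply probes the standard torch backends:

import torch

def get_available_devices():
    # Hypothetical sketch; the real helper lives in tests/helpers.py and is not part of this hunk.
    devices = ["cpu"]
    if torch.cuda.is_available():  # covers both NVIDIA CUDA and ROCm builds
        devices.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices

Tests then parametrize over the returned list, e.g. @pytest.mark.parametrize("device", get_available_devices()), and skip or xfail the combinations a backend does not support yet, which is the pattern visible in the hunks above.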
...@@ -4,11 +4,11 @@ import pytest ...@@ -4,11 +4,11 @@ import pytest
import torch import torch
import bitsandbytes import bitsandbytes
from tests.helpers import TRUE_FALSE, id_formatter from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
class TestLLMInt8Ops: class TestLLMInt8Ops:
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
def test_int8_linear_matmul(self, device): def test_int8_linear_matmul(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
...@@ -20,7 +20,7 @@ class TestLLMInt8Ops: ...@@ -20,7 +20,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
def test_int8_linear_matmul_out(self, device): def test_int8_linear_matmul_out(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
...@@ -35,7 +35,7 @@ class TestLLMInt8Ops: ...@@ -35,7 +35,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))
@pytest.mark.parametrize("threshold", [0.0, 6.0]) @pytest.mark.parametrize("threshold", [0.0, 6.0])
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
def test_int8_vectorwise_quant(self, threshold, device): def test_int8_vectorwise_quant(self, threshold, device):
if device == "cpu": if device == "cpu":
pytest.skip("CPU implementation is not available") pytest.skip("CPU implementation is not available")
...@@ -64,7 +64,7 @@ class TestLLMInt8Ops: ...@@ -64,7 +64,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
def test_int8_mm_dequant(self, device): def test_int8_mm_dequant(self, device):
A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
row_stats = torch.randn(256, dtype=torch.float32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device)
...@@ -77,7 +77,7 @@ class TestLLMInt8Ops: ...@@ -77,7 +77,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE) @pytest.mark.parametrize("has_bias", TRUE_FALSE)
def test_int8_scaled_mm(self, device, dtype, has_bias): def test_int8_scaled_mm(self, device, dtype, has_bias):
...@@ -96,7 +96,7 @@ class TestLLMInt8Ops: ...@@ -96,7 +96,7 @@ class TestLLMInt8Ops:
class TestInt8BlockwiseQuantOps: class TestInt8BlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_blockwise(self, device, dtype, blocksize): def test_quantize_blockwise(self, device, dtype, blocksize):
...@@ -116,7 +116,7 @@ class TestInt8BlockwiseQuantOps: ...@@ -116,7 +116,7 @@ class TestInt8BlockwiseQuantOps:
torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_blockwise(self, device, dtype, blocksize): def test_dequantize_blockwise(self, device, dtype, blocksize):
...@@ -140,7 +140,7 @@ class TestInt8BlockwiseQuantOps: ...@@ -140,7 +140,7 @@ class TestInt8BlockwiseQuantOps:
class Test4bitBlockwiseQuantOps: class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
...@@ -164,7 +164,7 @@ class Test4bitBlockwiseQuantOps: ...@@ -164,7 +164,7 @@ class Test4bitBlockwiseQuantOps:
torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
...@@ -197,7 +197,7 @@ class Test4bitBlockwiseQuantOps: ...@@ -197,7 +197,7 @@ class Test4bitBlockwiseQuantOps:
torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
) )
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......
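The ops tests all follow one pattern: build inputs on the parametrized device, call the custom operator under torch.ops.bitsandbytes, check the result, then validate the registration with torch.library.opcheck. Condensed into a single self-contained example (a sketch only; shapes and calls follow the int8 matmul test above):

import pytest
import torch

from tests.helpers import get_available_devices

@pytest.mark.parametrize("device", get_available_devices())
def test_int8_linear_matmul_pattern(device):
    # int8 operands on the device under test; B is laid out as (out_features, in_features).
    A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
    B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
    out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
    assert out.shape == (10, 30)
    # opcheck validates the operator's schema and registrations (fake tensor, autograd, etc.).
    torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))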
...@@ -47,7 +47,6 @@ str2optimizers["momentum_pytorch"] = ( ...@@ -47,7 +47,6 @@ str2optimizers["momentum_pytorch"] = (
) )
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
...@@ -88,19 +87,14 @@ str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = ( ...@@ -88,19 +87,14 @@ str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
) )
str2optimizers["lion"] = (Lion, bnb.optim.Lion) str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion) str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum"] = ( str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = ( str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
...@@ -110,10 +104,6 @@ str2optimizers["rmsprop"] = ( ...@@ -110,10 +104,6 @@ str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = ( str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
...@@ -128,8 +118,7 @@ str2statenames["paged_lion"] = [("exp_avg", "state1")] ...@@ -128,8 +118,7 @@ str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [ str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"), ("exp_avg_sq", "state2", "qmap2", "absmax2"),
...@@ -142,10 +131,8 @@ str2statenames["paged_adamw8bit_blockwise"] = [ ...@@ -142,10 +131,8 @@ str2statenames["paged_adamw8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"), ("exp_avg_sq", "state2", "qmap2", "absmax2"),
] ]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
...@@ -180,7 +167,7 @@ optimizer_names_32bit = [ ...@@ -180,7 +167,7 @@ optimizer_names_32bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name): def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
pytest.skip() pytest.skip()
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
...@@ -256,7 +243,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): ...@@ -256,7 +243,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_global_config(dim1, dim2, gtype): def test_global_config(requires_cuda, dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
...@@ -298,10 +285,11 @@ def test_global_config(dim1, dim2, gtype): ...@@ -298,10 +285,11 @@ def test_global_config(dim1, dim2, gtype):
optimizer_names_8bit = [ optimizer_names_8bit = [
"adam8bit", # Non-blockwise optimizers are deprecated.
"lion8bit", # "adam8bit",
"momentum8bit", # "lion8bit",
"rmsprop8bit", # "momentum8bit",
# "rmsprop8bit",
"adam8bit_blockwise", "adam8bit_blockwise",
"lion8bit_blockwise", "lion8bit_blockwise",
"momentum8bit_blockwise", "momentum8bit_blockwise",
...@@ -315,7 +303,7 @@ optimizer_names_8bit = [ ...@@ -315,7 +303,7 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name): def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
torch.set_printoptions(precision=6) torch.set_printoptions(precision=6)
if gtype == torch.bfloat16 and "blockwise" not in optim_name: if gtype == torch.bfloat16 and "blockwise" not in optim_name:
...@@ -479,7 +467,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ...@@ -479,7 +467,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype) @pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): @pytest.mark.deprecated
def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
......
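Unlike the module and ops tests, the optimizer tests are not device-parametrized; they now request a requires_cuda fixture, since these kernels remain CUDA-only. The fixture is defined in the suite's conftest and is not shown in this diff; a minimal sketch, assuming it simply skips when no CUDA device is present:

import pytest
import torch

@pytest.fixture
def requires_cuda():
    # Hypothetical sketch of the fixture requested by the optimizer tests above;
    # the real definition lives in conftest.py, which is not part of this diff.
    if not torch.cuda.is_available():
        pytest.skip("this test requires a CUDA device")
    return True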