Unverified Commit 1088ec52 authored by Matthew Douglas, committed by GitHub

Updates for device agnosticism (#1601)

* Include device support tags for transformers multi-backend compatibility; add xpu() and cpu() to Params4bit

* Make test suite more device-agnostic

* Additional device agnostic tests

* Additional device agnosticism for tests

* Add BNB_TEST_DEVICE env var to manually select device for unit tests

* Small bugfix for int8 test

* Exclude backward() from code coverage reports

* Params4bit: don't try to quantize when moving to meta device
parent 97073cdb
......@@ -20,15 +20,15 @@ from .nn import modules
from .optim import adam
# This is a signal for integrations with transformers/diffusers.
# Eventually, we will remove this and check based on release version.
# Eventually we may remove this but it is currently required for compatibility.
features = {"multi-backend"}
supported_torch_devices = {
"cuda",
"cpu",
# "mps",
# "xpu",
# "hpu",
# "npu",
"cuda", # NVIDIA/AMD GPU
"xpu", # Intel GPU
"hpu", # Gaudi
"npu", # Ascend NPU
"mps", # Apple Silicon
}
if torch.cuda.is_available():
......
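Downstream integrations (transformers, diffusers) can key multi-backend behavior off these module-level attributes. A minimal sketch of such a check, assuming only the public `features` and `supported_torch_devices` sets introduced above:

import bitsandbytes as bnb

def pick_bnb_device(preferred: str = "cuda") -> str:
    # Fall back to CPU when the preferred backend is not advertised by bitsandbytes.
    if "multi-backend" in getattr(bnb, "features", set()) and preferred in bnb.supported_torch_devices:
        return preferred
    return "cpu"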
......@@ -284,7 +284,7 @@ class MatMul8bitLt(torch.autograd.Function):
dtype=torch.float16,
)
if state.threshold > 0.0 and subA is not None:
if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
......
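The added numel() check covers the case where the outlier threshold selects no columns, so subA is an empty tensor and the gradient update would operate on no data. A small illustration of the guarded condition (hypothetical shapes, not code from the PR):

import torch

grad_output = torch.randn(8, 16, dtype=torch.float16)
subA = torch.empty(8, 0, dtype=torch.float16)  # no columns crossed the threshold
if subA is not None and subA.numel() > 0:      # the new guard skips the empty update
    print(torch.matmul(grad_output.t(), subA).shape)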
......@@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
for i in range(gap):
values.append(0)
values.sort()
code = torch.Tensor(values)
code = torch.tensor(values)
code /= code.max()
return code
......
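Switching from the legacy torch.Tensor(values) constructor to torch.tensor(values) is the idiomatic way to build a tensor from a Python list; the factory function infers the dtype from the data instead of always producing the default tensor type. A quick illustration (the values here are placeholders):

import torch

values = [-1.0, -0.5, 0.0, 0.5, 1.0]
code = torch.tensor(values)   # float32 inferred from the Python floats
code /= code.max()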
......@@ -306,9 +306,15 @@ class Params4bit(torch.nn.Parameter):
self.bnb_quantized = True
return self
def cpu(self):
return self.to(device="cpu")
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
@overload
def to(
self: T,
......@@ -326,7 +332,7 @@ class Params4bit(torch.nn.Parameter):
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "cuda" and not self.bnb_quantized:
if device is not None and device.type != "meta" and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
......
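Taken together, the new cpu()/xpu() helpers and the relaxed condition in to() mean a Params4bit quantizes on its first move to any real device, while a move to the meta device leaves it unquantized. A hedged sketch of the intended usage, assuming a backend with 4-bit kernels is installed:

import torch
import bitsandbytes as bnb

w = torch.randn(64, 64, dtype=torch.float16)
p = bnb.nn.Params4bit(data=w, quant_type="nf4", requires_grad=False)

p_meta = p.to("meta")            # meta move: no quantization is attempted
if torch.cuda.is_available():
    p_cuda = p.cuda()            # first move to a real device triggers quantization
    p_cpu = p_cuda.cpu()         # new helper added in this PR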
......@@ -79,6 +79,12 @@ include = ["bitsandbytes*"]
[tool.setuptools.dynamic]
version = {attr = "bitsandbytes.__version__"}
[tool.coverage.report]
exclude_also = [
# exclude backward() functions from coverage, as they are invoked from C++
'def backward\(ctx'
]
[tool.pytest.ini_options]
addopts = "-rP -m 'not slow and not benchmark and not deprecated'"
# ; --cov=bitsandbytes
......
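The exclude_also pattern drops autograd backward bodies from coverage reporting; they are invoked from C++ dispatch, so the Python tracer never records them as executed. Any method matching the regex is excluded, for example (illustrative, not code from the repository):

import torch

class _ExampleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):  # matches 'def backward\(ctx' and is excluded
        return grad_output * 2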
import functools
from io import BytesIO
from itertools import product
import os
import random
from typing import Any
......@@ -13,6 +15,38 @@ BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3)) # all combinations of (bo
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool)
@functools.cache
def get_available_devices():
if "BNB_TEST_DEVICE" in os.environ:
# If the environment variable is set, use it directly.
return [os.environ["BNB_TEST_DEVICE"]]
devices = ["cpu"]
if hasattr(torch, "accelerator"):
# PyTorch 2.6+ - determine accelerator using agnostic API.
if torch.accelerator.is_available():
devices += [str(torch.accelerator.current_accelerator())]
else:
if torch.cuda.is_available():
devices += ["cuda"]
if torch.backends.mps.is_available():
devices += ["mps"]
if hasattr(torch, "xpu") and torch.xpu.is_available():
devices += ["xpu"]
custom_backend_name = torch._C._get_privateuse1_backend_name()
custom_backend_module = getattr(torch, custom_backend_name, None)
custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)
if custom_backend_is_available_fn and custom_backend_module.is_available():
devices += [custom_backend_name]
return devices
def torch_save_to_buffer(obj):
buffer = BytesIO()
torch.save(obj, buffer)
......
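get_available_devices() drives the new device parametrization, and BNB_TEST_DEVICE short-circuits the autodetection, which is handy for CI or for pinning a multi-accelerator machine to one backend. A hypothetical way to exercise it from the repository root (the variable must be set before the cached helper is first called):

import os
os.environ["BNB_TEST_DEVICE"] = "cpu"     # must be set before the first (cached) call

from tests.helpers import get_available_devices
print(get_available_devices())            # -> ["cpu"]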
......@@ -6,12 +6,14 @@ from tests.helpers import (
BOOLEAN_TRIPLES,
TRUE_FALSE,
describe_dtype,
get_available_devices,
id_formatter,
)
TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
......@@ -27,10 +29,16 @@ TRANSPOSE_VALS = [(False, True), (False, False)]
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
def test_matmullt(
device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
):
if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
# TODO: Deprecate/remove?
pytest.skip("switchback_bnb only works on CUDA.")
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
if has_bias == False:
req_grad = list(req_grad)
req_grad[2] = False
......@@ -38,21 +46,21 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(
size=(dim2, dim4),
device="cuda",
device=device,
requires_grad=req_grad[1],
dtype=dtype,
)
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
......@@ -91,7 +99,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
if has_fp16_weights:
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
if device == "cuda":
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
......@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
torch.testing.assert_close(gradBias1, gradBias2)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
......@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
def test_matmul_4bit(
device,
dim1,
dim2,
dim3,
......@@ -159,6 +170,9 @@ def test_matmul_4bit(
compress_statistics,
quant_type,
):
if device == "cpu" and quant_type == "fp4":
pytest.skip("Only nf4 is supported on CPU")
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
if has_bias == False:
......@@ -168,13 +182,13 @@ def test_matmul_4bit(
for i in range(3):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)
......@@ -204,7 +218,8 @@ def test_matmul_4bit(
# assert err < 0.20
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
if device == "cuda":
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
......
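Both matmul tests now synchronize only on CUDA. The PR keeps the explicit device check; on PyTorch 2.6+ the same intent could be expressed with the device-agnostic accelerator API, sketched below as an alternative (not what the PR does):

import torch

def maybe_synchronize(device: str) -> None:
    # Backend-agnostic synchronization where available, CUDA-only fallback otherwise.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available() and device != "cpu":
        torch.accelerator.synchronize()
    elif device == "cuda":
        torch.cuda.synchronize()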
This diff is collapsed.
......@@ -7,7 +7,7 @@ import pytest
import torch
import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
storage = {
"uint8": torch.uint8,
......@@ -17,15 +17,18 @@ storage = {
}
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
if device == "cpu":
pytest.xfail("Dequantization is not yet implemented for CPU")
original_dtype = torch.float16
compute_dtype = None
device = "cuda"
layer_shape = (300, 400)
linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer
......@@ -52,7 +55,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
# restoring from state_dict:
bias_data2 = sd.pop("bias", None)
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2, device=device)
# creating new layer with same params:
linear_q2 = bnb.nn.Linear4bit(
......@@ -174,18 +177,50 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert size_ratio < target_compression, ratio_error_msg
def test_copy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
tensor = torch.linspace(1, blocksize, blocksize)
param = bnb.nn.Params4bit(
data=tensor,
quant_type=quant_type,
blocksize=blocksize,
compress_statistics=compress_statistics,
requires_grad=False,
).to(device)
shallow_copy_param = copy.copy(param)
assert param.quant_state is shallow_copy_param.quant_state
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
def test_deepcopy_param():
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
tensor = torch.linspace(1, blocksize, blocksize)
param = bnb.nn.Params4bit(
data=tensor,
quant_type=quant_type,
blocksize=blocksize,
compress_statistics=compress_statistics,
requires_grad=False,
).to(device)
dict_keys_before = set(param.__dict__.keys())
copy_param = copy.deepcopy(param)
dict_keys_after = set(param.__dict__.keys())
......@@ -199,12 +234,27 @@ def test_deepcopy_param():
assert dict_keys_before == dict_keys_copy
def test_params4bit_real_serialization():
original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
original_param = bnb.nn.Params4bit(
data=original_tensor,
quant_type=quant_type,
blocksize=blocksize,
compress_statistics=compress_statistics,
)
dict_keys_before = set(original_param.__dict__.keys())
original_param.cuda(0) # move to CUDA to trigger quantization
original_param.to(device) # change device to trigger quantization
serialized_param = pickle.dumps(original_param)
deserialized_param = pickle.loads(serialized_param)
......
......@@ -11,6 +11,7 @@ import bitsandbytes as bnb
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
get_available_devices,
id_formatter,
torch_load_from_buffer,
torch_save_to_buffer,
......@@ -19,7 +20,11 @@ from tests.helpers import (
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
def test_linear_no_igemmlt():
@pytest.mark.parametrize("device", get_available_devices())
def test_linear_no_igemmlt(device):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
linear = torch.nn.Linear(1024, 3072)
x = torch.randn(3, 1024, dtype=torch.half)
linear_custom = Linear8bitLt(
......@@ -29,6 +34,8 @@ def test_linear_no_igemmlt():
has_fp16_weights=False,
threshold=6.0,
)
# TODO: Remove, this is no longer implemented
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
......@@ -37,11 +44,11 @@ def test_linear_no_igemmlt():
has_fp16_weights=False,
).to(linear.weight.dtype)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
linear = linear.half().cuda()
linear_custom = linear_custom.to(device)
linear = linear.half().to(device)
x_ref = x.clone().cuda().requires_grad_(True)
x_ours = x.clone().cuda().requires_grad_(True)
x_ref = x.clone().to(device).requires_grad_(True)
x_ours = x.clone().to(device).requires_grad_(True)
fx_ref = linear(x_ref).float()
grad_proj = torch.randn_like(fx_ref)
(fx_ref * grad_proj).mean().backward()
......@@ -58,18 +65,25 @@ def test_linear_no_igemmlt():
torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
def test_linear_serialization(
device,
has_fp16_weights,
threshold,
serialize_before_forward,
deserialize_before_cuda,
save_before_forward,
load_before_cuda,
):
if device == "cpu":
pytest.xfail("Not yet implemented on CPU")
linear = torch.nn.Linear(32, 96)
# TODO: Fallback for bad shapes
x = torch.randn(4, 32, dtype=torch.half)
......@@ -80,7 +94,7 @@ def test_linear_serialization(
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
threshold=threshold,
)
linear_custom.weight = bnb.nn.Int8Params(
......@@ -89,7 +103,7 @@ def test_linear_serialization(
has_fp16_weights=has_fp16_weights,
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
linear_custom = linear_custom.to(device)
if serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()
......@@ -125,7 +139,7 @@ def test_linear_serialization(
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
threshold=threshold,
)
if deserialize_before_cuda:
......@@ -135,7 +149,7 @@ def test_linear_serialization(
if load_before_cuda:
new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
new_linear_custom = new_linear_custom.cuda()
new_linear_custom = new_linear_custom.to(device)
if not deserialize_before_cuda:
new_linear_custom.load_state_dict(new_state_dict, strict=True)
......
This diff is collapsed.
......@@ -4,11 +4,11 @@ import pytest
import torch
import bitsandbytes
from tests.helpers import TRUE_FALSE, id_formatter
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
class TestLLMInt8Ops:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_linear_matmul(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
......@@ -20,7 +20,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_linear_matmul_out(self, device):
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
......@@ -35,7 +35,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))
@pytest.mark.parametrize("threshold", [0.0, 6.0])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_vectorwise_quant(self, threshold, device):
if device == "cpu":
pytest.skip("CPU implementation is not available")
......@@ -64,7 +64,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
def test_int8_mm_dequant(self, device):
A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
row_stats = torch.randn(256, dtype=torch.float32, device=device)
......@@ -77,7 +77,7 @@ class TestLLMInt8Ops:
torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE)
def test_int8_scaled_mm(self, device, dtype, has_bias):
......@@ -96,7 +96,7 @@ class TestLLMInt8Ops:
class TestInt8BlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_blockwise(self, device, dtype, blocksize):
......@@ -116,7 +116,7 @@ class TestInt8BlockwiseQuantOps:
torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_blockwise(self, device, dtype, blocksize):
......@@ -140,7 +140,7 @@ class TestInt8BlockwiseQuantOps:
class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......@@ -164,7 +164,7 @@ class Test4bitBlockwiseQuantOps:
torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......@@ -197,7 +197,7 @@ class Test4bitBlockwiseQuantOps:
torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......
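These ops tests now enumerate every available backend instead of hard-coding cpu/cuda. Each test builds inputs on the target device and runs torch.library.opcheck against the registered custom op; a stand-alone sketch of that pattern for the blockwise quantizer (device choice and blocksize are illustrative):

import torch
import bitsandbytes  # registers the torch.ops.bitsandbytes namespace
from bitsandbytes.functional import create_dynamic_map

device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(1024, dtype=torch.float16, device=device)
code = create_dynamic_map().to(device)
torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, 256))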
......@@ -47,7 +47,6 @@ str2optimizers["momentum_pytorch"] = (
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
......@@ -88,19 +87,14 @@ str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
......@@ -110,10 +104,6 @@ str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
......@@ -128,8 +118,7 @@ str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
......@@ -142,10 +131,8 @@ str2statenames["paged_adamw8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
......@@ -180,7 +167,7 @@ optimizer_names_32bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
pytest.skip()
if dim1 == 1 and dim2 == 1:
......@@ -256,7 +243,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
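The optimizer tests gain a requires_cuda argument, presumably a fixture (defined in conftest.py, not shown in this diff) that skips the test when no CUDA device is present, since the optimizers remain CUDA-only. A minimal sketch of what such a fixture could look like:

import pytest
import torch

@pytest.fixture
def requires_cuda():
    if not torch.cuda.is_available():
        pytest.skip("this test requires a CUDA device")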
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_global_config(dim1, dim2, gtype):
def test_global_config(requires_cuda, dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
......@@ -298,10 +285,11 @@ def test_global_config(dim1, dim2, gtype):
optimizer_names_8bit = [
"adam8bit",
"lion8bit",
"momentum8bit",
"rmsprop8bit",
# Non-blockwise optimizers are deprecated.
# "adam8bit",
# "lion8bit",
# "momentum8bit",
# "rmsprop8bit",
"adam8bit_blockwise",
"lion8bit_blockwise",
"momentum8bit_blockwise",
......@@ -315,7 +303,7 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
torch.set_printoptions(precision=6)
if gtype == torch.bfloat16 and "blockwise" not in optim_name:
......@@ -479,7 +467,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
@pytest.mark.deprecated
def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
......