Commit 223fea51 authored by Egor Krivov's avatar Egor Krivov
Browse files

Add no_cpu for optimizers

parent 3b89a05e
......@@ -18,12 +18,12 @@ BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (boo
@functools.cache
def get_available_devices():
def get_available_devices(no_cpu=False):
if "BNB_TEST_DEVICE" in os.environ:
# If the environment variable is set, use it directly.
return [os.environ["BNB_TEST_DEVICE"]]
devices = [] if HIP_ENVIRONMENT else ["cpu"]
devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []
if hasattr(torch, "accelerator"):
# PyTorch 2.6+ - determine accelerator using agnostic API.
......
......@@ -169,7 +169,7 @@ optimizer_names_32bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
@pytest.mark.parametrize("device", get_available_devices(), ids=id_formatter("device"))
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
if optim_name.startswith("paged_") and sys.platform == "win32":
pytest.skip("Paged optimizers can have issues on Windows.")
......@@ -249,7 +249,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
def test_global_config(dim1, dim2, gtype, device):
if dim1 == 1 and dim2 == 1:
return
......@@ -305,7 +305,7 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
torch.set_printoptions(precision=6)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment