Commit 223fea51 authored by Egor Krivov

Add no_cpu for optimizers

parent 3b89a05e
@@ -18,12 +18,12 @@ BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (boo
 @functools.cache
-def get_available_devices():
+def get_available_devices(no_cpu=False):
     if "BNB_TEST_DEVICE" in os.environ:
         # If the environment variable is set, use it directly.
         return [os.environ["BNB_TEST_DEVICE"]]
-    devices = [] if HIP_ENVIRONMENT else ["cpu"]
+    devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []
     if hasattr(torch, "accelerator"):
         # PyTorch 2.6+ - determine accelerator using agnostic API.
...
@@ -169,7 +169,7 @@ optimizer_names_32bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("device", get_available_devices(), ids=id_formatter("device"))
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
 def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
     if optim_name.startswith("paged_") and sys.platform == "win32":
         pytest.skip("Paged optimizers can have issues on Windows.")
@@ -249,7 +249,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
-@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 def test_global_config(dim1, dim2, gtype, device):
     if dim1 == 1 and dim2 == 1:
         return
@@ -305,7 +305,7 @@ optimizer_names_8bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
     torch.set_printoptions(precision=6)
...