Commit 14147f6f authored by Matthew Douglas's avatar Matthew Douglas
Browse files

Test fix

parent 941681da
...@@ -21,7 +21,8 @@ BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (boo ...@@ -21,7 +21,8 @@ BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (boo
def get_available_devices(no_cpu=False): def get_available_devices(no_cpu=False):
if "BNB_TEST_DEVICE" in os.environ: if "BNB_TEST_DEVICE" in os.environ:
# If the environment variable is set, use it directly. # If the environment variable is set, use it directly.
return [d for d in os.environ["BNB_TEST_DEVICE"] if d.lower() != "cpu"] device = os.environ["BNB_TEST_DEVICE"]
return [] if no_cpu and device == "cpu" else [device]
devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else [] devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []
......
...@@ -170,6 +170,7 @@ optimizer_names_32bit = [ ...@@ -170,6 +170,7 @@ optimizer_names_32bit = [
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device")) @pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
if optim_name.startswith("paged_") and sys.platform == "win32": if optim_name.startswith("paged_") and sys.platform == "win32":
pytest.skip("Paged optimizers can have issues on Windows.") pytest.skip("Paged optimizers can have issues on Windows.")
...@@ -250,6 +251,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): ...@@ -250,6 +251,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True)) @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
def test_global_config(dim1, dim2, gtype, device): def test_global_config(dim1, dim2, gtype, device):
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
...@@ -306,6 +308,7 @@ optimizer_names_8bit = [ ...@@ -306,6 +308,7 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True)) @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
torch.set_printoptions(precision=6) torch.set_printoptions(precision=6)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment