Unverified commit 7b127ccb, authored by Benjamin Lefaudeux and committed by GitHub
Browse files

[fix][OSS] enabling disabled tests for 1.8 (#534)

* enabling disabled tests
parent 8b59267b
......@@ -415,6 +415,7 @@ def test_sharding():
def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
torch.cuda.set_device(rank)
# Run a dummy step so that the optimizer state dict exists
batch, input_width, hidden, target_width = 3, 3, 3, 5
......@@ -472,14 +473,10 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
dist.destroy_process_group()
# TODO(blefaudeux) Fix for torch v1.8.0
@pytest.mark.skipif(torch.__version__.split("+")[0].split(".") == ["1", "8", "0"], reason="disabled for torch 1.8.0")
@skip_if_single_gpu
def test_collect_shards():
world_size = 3
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count())
reference_rank = 0
mp.spawn(
......@@ -487,9 +484,10 @@ def test_collect_shards():
)
def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
def run_test_reproducibility(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
torch.cuda.set_device(rank)
# Run a dummy step so that the optimizer state dict exists
batch, input_width, hidden, target_width = 3, 3, 3, 5
......@@ -535,21 +533,13 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
dist.destroy_process_group()
# TODO(blefaudeux) Fix for torch v1.8.0
@pytest.mark.skipif(torch.__version__.split("+")[0].split(".") == ["1", "8", "0"], reason="disabled for torch 1.8.0")
@skip_if_single_gpu
def test_reproducibility():
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
# Bail out if not enough devices
return
reference_rank = 0
mp.spawn(
run_test_reproducibility, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
run_test_reproducibility, args=(world_size, temp_file_name), nprocs=world_size, join=True,
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment