Unverified Commit 7b127ccb authored by Benjamin Lefaudeux, committed by GitHub

[fix][OSS] enabling disabled tests for 1.8 (#534)

* enabling disabled tests
parent 8b59267b
@@ -415,6 +415,7 @@ def test_sharding():
 def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     dist_init(rank, world_size, tempfile_name)
     device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
+    torch.cuda.set_device(rank)
 
     # Run a dummy step so that the optimizer state dict exists
     batch, input_width, hidden, target_width = 3, 3, 3, 5
@@ -472,14 +473,10 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     dist.destroy_process_group()
 
 
-# TODO(blefaudeux) Fix for torch v1.8.0
-@pytest.mark.skipif(torch.__version__.split("+")[0].split(".") == ["1", "8", "0"], reason="disabled for torch 1.8.0")
+@skip_if_single_gpu
 def test_collect_shards():
-    world_size = 3
+    world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
-    if torch.cuda.is_available():
-        world_size = min(world_size, torch.cuda.device_count())
     reference_rank = 0
 
     mp.spawn(
@@ -487,9 +484,10 @@ def test_collect_shards():
     )
 
 
-def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
+def run_test_reproducibility(rank, world_size, tempfile_name):
     dist_init(rank, world_size, tempfile_name)
     device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
+    torch.cuda.set_device(rank)
 
     # Run a dummy step so that the optimizer state dict exists
     batch, input_width, hidden, target_width = 3, 3, 3, 5
@@ -535,21 +533,13 @@ def run_test_reproducibility(rank, world_size, reference_rank, tempfile_name):
     dist.destroy_process_group()
 
 
-# TODO(blefaudeux) Fix for torch v1.8.0
-@pytest.mark.skipif(torch.__version__.split("+")[0].split(".") == ["1", "8", "0"], reason="disabled for torch 1.8.0")
 @skip_if_single_gpu
 def test_reproducibility():
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
-    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
-        # Bail out if not enough devices
-        return
-    reference_rank = 0
 
     mp.spawn(
-        run_test_reproducibility, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
+        run_test_reproducibility, args=(world_size, temp_file_name), nprocs=world_size, join=True,
     )
...
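
For context, here is a minimal, self-contained sketch of the spawn-and-file-rendezvous pattern these tests rely on. It assumes the test file's dist_init helper wraps torch.distributed.init_process_group with a file:// store backed by the temp file; the names _worker and _run below are illustrative, not from this patch.

import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, tempfile_name: str) -> None:
    # File-based rendezvous: every spawned process opens the same temp file.
    # Assumption: this mirrors what the test suite's dist_init helper does.
    dist.init_process_group(
        backend="gloo",  # a CUDA-enabled run would typically use "nccl"
        init_method=f"file://{tempfile_name}",
        rank=rank,
        world_size=world_size,
    )
    if torch.cuda.is_available():
        # Pin this process to its own GPU, as the patch above now does.
        # Without it, per-rank work can silently land on cuda:0.
        torch.cuda.set_device(rank)
    dist.destroy_process_group()


def _run() -> None:
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]
    # One process per rank; join=True blocks until every rank exits.
    mp.spawn(_worker, args=(world_size, temp_file_name), nprocs=world_size, join=True)


if __name__ == "__main__":
    _run()

The likely motivation for replacing the version checks and hand-rolled device-count guards with @skip_if_single_gpu and a fixed world_size of 2 is that each test now either runs with one GPU per rank or is skipped outright, rather than shrinking its world size at runtime.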