"vscode:/vscode.git/clone" did not exist on "3ee899fa0c0a443db371848a87582b2e2295852d"
Unverified Commit e6aef938 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS Cpu tests (#333)

parent 38ad8638
...@@ -237,7 +237,7 @@ def run_test_add_param_group(rank, world_size, tempfile_name): ...@@ -237,7 +237,7 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
def test_add_param_group(): def test_add_param_group():
world_size = 4 world_size = 4
if not torch.cuda.is_available() or torch.cuda.device_count() < world_size: if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
world_size = min(world_size, torch.cuda.device_count()) world_size = min(world_size, torch.cuda.device_count())
temp_file_name = tempfile.mkstemp()[1] temp_file_name = tempfile.mkstemp()[1]
...@@ -262,6 +262,9 @@ def run_test_zero_grad(rank, world_size, tempfile_name): ...@@ -262,6 +262,9 @@ def run_test_zero_grad(rank, world_size, tempfile_name):
def test_zero_grad(): def test_zero_grad():
world_size = 2 world_size = 2
if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
world_size = min(world_size, torch.cuda.device_count())
temp_file_name = tempfile.mkstemp()[1] temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True) mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
...@@ -474,7 +477,11 @@ def run_test_multiple_groups(rank, world_size, tempfile_name): ...@@ -474,7 +477,11 @@ def run_test_multiple_groups(rank, world_size, tempfile_name):
dist.gather(p, receptacle, dst=0, group=process_group) dist.gather(p, receptacle, dst=0, group=process_group)
if rank == 0: if rank == 0:
for sync_p in receptacle[1:]: for sync_p in receptacle[1:]:
assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks" assert torch.all(
torch.eq(receptacle[0], sync_p)
), "Models differ in between ranks {} - {}".format(
torch.norm(receptacle[0]), torch.norm(sync_p)
)
if rank in sub_group_ranks: if rank in sub_group_ranks:
# Model fitting in the broadcast bucket # Model fitting in the broadcast bucket
...@@ -498,7 +505,6 @@ def run_test_multiple_groups(rank, world_size, tempfile_name): ...@@ -498,7 +505,6 @@ def run_test_multiple_groups(rank, world_size, tempfile_name):
check(optimizer) check(optimizer)
dist.destroy_process_group(process_group) dist.destroy_process_group(process_group)
dist.destroy_process_group()
def test_multiple_groups(): def test_multiple_groups():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment