"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "e0753f0b0d7fbbc07556b3e3d2bf7116b784153d"
Unverified Commit e6aef938 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS Cpu tests (#333)

parent 38ad8638
......@@ -237,7 +237,7 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
def test_add_param_group():
world_size = 4
if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
world_size = min(world_size, torch.cuda.device_count())
temp_file_name = tempfile.mkstemp()[1]
......@@ -262,6 +262,9 @@ def run_test_zero_grad(rank, world_size, tempfile_name):
def test_zero_grad():
world_size = 2
if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
world_size = min(world_size, torch.cuda.device_count())
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
......@@ -474,7 +477,11 @@ def run_test_multiple_groups(rank, world_size, tempfile_name):
dist.gather(p, receptacle, dst=0, group=process_group)
if rank == 0:
for sync_p in receptacle[1:]:
assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
assert torch.all(
torch.eq(receptacle[0], sync_p)
), "Models differ in between ranks {} - {}".format(
torch.norm(receptacle[0]), torch.norm(sync_p)
)
if rank in sub_group_ranks:
# Model fitting in the broadcast bucket
......@@ -498,7 +505,6 @@ def run_test_multiple_groups(rank, world_size, tempfile_name):
check(optimizer)
dist.destroy_process_group(process_group)
dist.destroy_process_group()
def test_multiple_groups():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment