"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "463cdeab31c61f33a39d4adf1d88f1f4c26689d3"
Unverified Commit d9e6ceaa authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[fix] fix test_oss.py when host have 2 GPUs (#26)


Co-authored-by: default avatarMin Xu <m1n@fb.com>
parent 525e709b
...@@ -125,7 +125,7 @@ def run_test_step(rank, world_size): ...@@ -125,7 +125,7 @@ def run_test_step(rank, world_size):
@skip_if_no_cuda @skip_if_no_cuda
def test_step(): def test_step():
world_size = 2 world_size = min(2, torch.cuda.device_count())
mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True) mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
...@@ -169,7 +169,7 @@ def run_test_step_with_closure(rank, world_size, optimizer=None): ...@@ -169,7 +169,7 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
@skip_if_no_cuda @skip_if_no_cuda
def test_step_with_closure(): def test_step_with_closure():
world_size = 2 world_size = min(2, torch.cuda.device_count())
mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True) mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True)
...@@ -236,6 +236,8 @@ def run_test_collect_shards(rank, world_size, reference_rank): ...@@ -236,6 +236,8 @@ def run_test_collect_shards(rank, world_size, reference_rank):
def test_collect_shards(): def test_collect_shards():
world_size = 3 world_size = 3
if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count())
reference_rank = 0 reference_rank = 0
mp.spawn( mp.spawn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment