Commit bb34a375 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

support specifying backend for testing helper.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/513

support specifying backend for testing helper.

Reviewed By: tglik

Differential Revision: D44401470

fbshipit-source-id: 9c7962cf40d3c677f9a3c7bfa9cdf5dcecae2ba9
parent 46606a02
...@@ -41,36 +41,39 @@ def skip_if_no_gpu(func): ...@@ -41,36 +41,39 @@ def skip_if_no_gpu(func):
return wrapper return wrapper
def enable_ddp_env(func): def enable_ddp_env(backend="gloo"):
@wraps(func) def _enable_ddp_env(func):
def wrapper(*args, **kwargs): @wraps(func)
def find_free_port() -> str: def wrapper(*args, **kwargs):
s = socket.socket() def find_free_port() -> str:
s.bind(("localhost", 0)) # Bind to a free port provided by the host. s = socket.socket()
return str(s.getsockname()[1]) s.bind(("localhost", 0)) # Bind to a free port provided by the host.
return str(s.getsockname()[1])
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = find_free_port() os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = find_free_port()
return distributed_worker(
main_func=func, return distributed_worker(
args=args, main_func=func,
kwargs=kwargs, args=args,
backend="gloo", kwargs=kwargs,
init_method="file:///tmp/detectron2go_test_ddp_init_{}".format( backend=backend,
uuid.uuid4().hex init_method="file:///tmp/detectron2go_test_ddp_init_{}".format(
), uuid.uuid4().hex
dist_params=DistributedParams( ),
local_rank=0, dist_params=DistributedParams(
machine_rank=0, local_rank=0,
global_rank=0, machine_rank=0,
num_processes_per_machine=1, global_rank=0,
world_size=1, num_processes_per_machine=1,
), world_size=1,
return_save_file=None, # don't save file ),
) return_save_file=None, # don't save file
)
return wrapper
return wrapper
return _enable_ddp_env
def tempdir(func): def tempdir(func):
......
...@@ -27,7 +27,7 @@ class TestLightningTrainNet(unittest.TestCase): ...@@ -27,7 +27,7 @@ class TestLightningTrainNet(unittest.TestCase):
return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir) return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
@tempdir @tempdir
@enable_ddp_env @enable_ddp_env()
def test_train_net_main(self, root_dir): def test_train_net_main(self, root_dir):
"""tests the main training entry point.""" """tests the main training entry point."""
cfg = self._get_cfg(root_dir) cfg = self._get_cfg(root_dir)
...@@ -36,7 +36,7 @@ class TestLightningTrainNet(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestLightningTrainNet(unittest.TestCase):
main(cfg, root_dir, GeneralizedRCNNTask) main(cfg, root_dir, GeneralizedRCNNTask)
@tempdir @tempdir
@enable_ddp_env @enable_ddp_env()
def test_checkpointing(self, tmp_dir): def test_checkpointing(self, tmp_dir):
"""tests saving and loading from checkpoint.""" """tests saving and loading from checkpoint."""
cfg = self._get_cfg(tmp_dir) cfg = self._get_cfg(tmp_dir)
......
...@@ -237,7 +237,7 @@ class TestOptimizer(unittest.TestCase): ...@@ -237,7 +237,7 @@ class TestOptimizer(unittest.TestCase):
self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0 self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
) )
@helper.enable_ddp_env @helper.enable_ddp_env()
def test_create_optimizer_custom_ddp(self): def test_create_optimizer_custom_ddp(self):
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self): def __init__(self):
......
...@@ -173,7 +173,7 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -173,7 +173,7 @@ class TestDefaultRunner(unittest.TestCase):
) )
self.assertTrue(isinstance(runner, default_runner.Detectron2GoRunner)) self.assertTrue(isinstance(runner, default_runner.Detectron2GoRunner))
@helper.enable_ddp_env @helper.enable_ddp_env()
def test_d2go_runner_ema(self): def test_d2go_runner_ema(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10) ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment