Commit bb34a375 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

support specifying backend for testing helper.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/513

support specifying backend for testing helper.

Reviewed By: tglik

Differential Revision: D44401470

fbshipit-source-id: 9c7962cf40d3c677f9a3c7bfa9cdf5dcecae2ba9
parent 46606a02
...@@ -41,7 +41,8 @@ def skip_if_no_gpu(func): ...@@ -41,7 +41,8 @@ def skip_if_no_gpu(func):
return wrapper return wrapper
def enable_ddp_env(func): def enable_ddp_env(backend="gloo"):
def _enable_ddp_env(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
def find_free_port() -> str: def find_free_port() -> str:
...@@ -56,7 +57,7 @@ def enable_ddp_env(func): ...@@ -56,7 +57,7 @@ def enable_ddp_env(func):
main_func=func, main_func=func,
args=args, args=args,
kwargs=kwargs, kwargs=kwargs,
backend="gloo", backend=backend,
init_method="file:///tmp/detectron2go_test_ddp_init_{}".format( init_method="file:///tmp/detectron2go_test_ddp_init_{}".format(
uuid.uuid4().hex uuid.uuid4().hex
), ),
...@@ -72,6 +73,8 @@ def enable_ddp_env(func): ...@@ -72,6 +73,8 @@ def enable_ddp_env(func):
return wrapper return wrapper
return _enable_ddp_env
def tempdir(func): def tempdir(func):
"""A decorator for creating a tempory directory that is cleaned up after function execution.""" """A decorator for creating a tempory directory that is cleaned up after function execution."""
......
...@@ -27,7 +27,7 @@ class TestLightningTrainNet(unittest.TestCase): ...@@ -27,7 +27,7 @@ class TestLightningTrainNet(unittest.TestCase):
return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir) return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
@tempdir @tempdir
@enable_ddp_env @enable_ddp_env()
def test_train_net_main(self, root_dir): def test_train_net_main(self, root_dir):
"""tests the main training entry point.""" """tests the main training entry point."""
cfg = self._get_cfg(root_dir) cfg = self._get_cfg(root_dir)
...@@ -36,7 +36,7 @@ class TestLightningTrainNet(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestLightningTrainNet(unittest.TestCase):
main(cfg, root_dir, GeneralizedRCNNTask) main(cfg, root_dir, GeneralizedRCNNTask)
@tempdir @tempdir
@enable_ddp_env @enable_ddp_env()
def test_checkpointing(self, tmp_dir): def test_checkpointing(self, tmp_dir):
"""tests saving and loading from checkpoint.""" """tests saving and loading from checkpoint."""
cfg = self._get_cfg(tmp_dir) cfg = self._get_cfg(tmp_dir)
......
...@@ -237,7 +237,7 @@ class TestOptimizer(unittest.TestCase): ...@@ -237,7 +237,7 @@ class TestOptimizer(unittest.TestCase):
self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0 self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
) )
@helper.enable_ddp_env @helper.enable_ddp_env()
def test_create_optimizer_custom_ddp(self): def test_create_optimizer_custom_ddp(self):
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self): def __init__(self):
......
...@@ -173,7 +173,7 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -173,7 +173,7 @@ class TestDefaultRunner(unittest.TestCase):
) )
self.assertTrue(isinstance(runner, default_runner.Detectron2GoRunner)) self.assertTrue(isinstance(runner, default_runner.Detectron2GoRunner))
@helper.enable_ddp_env @helper.enable_ddp_env()
def test_d2go_runner_ema(self): def test_d2go_runner_ema(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10) ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment