Commit bb34a375, authored by Peizhao Zhang, committed by Facebook GitHub Bot
Browse files

Support specifying the backend for the testing helper.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/513

Support specifying the backend for the testing helper.

Reviewed By: tglik

Differential Revision: D44401470

fbshipit-source-id: 9c7962cf40d3c677f9a3c7bfa9cdf5dcecae2ba9
parent 46606a02
......@@ -41,36 +41,39 @@ def skip_if_no_gpu(func):
return wrapper
def enable_ddp_env(func):
    """Decorate ``func`` so that it is executed through ``distributed_worker``.

    The wrapper points ``MASTER_ADDR`` / ``MASTER_PORT`` at a free port on
    localhost and then hands ``func`` to ``distributed_worker`` with the
    "gloo" backend and a world size of 1 (one process, rank 0).

    Args:
        func: the callable (typically a test method) to run.

    Returns:
        A wrapper with the same signature as ``func``; it returns whatever
        ``distributed_worker`` returns.
    """

    def find_free_port() -> str:
        """Return, as a string, a port number the OS reports as free."""
        # Close the probe socket explicitly instead of relying on garbage
        # collection, so no descriptor (or ResourceWarning) leaks per test.
        with socket.socket() as s:
            s.bind(("localhost", 0))  # Bind to a free port provided by the host.
            return str(s.getsockname()[1])

    @wraps(func)
    def wrapper(*args, **kwargs):
        # NOTE: these environment variables are left set after the call.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = find_free_port()

        return distributed_worker(
            main_func=func,
            args=args,
            kwargs=kwargs,
            backend="gloo",
            # A unique file name per invocation keeps runs from colliding.
            init_method="file:///tmp/detectron2go_test_ddp_init_{}".format(
                uuid.uuid4().hex
            ),
            dist_params=DistributedParams(
                local_rank=0,
                machine_rank=0,
                global_rank=0,
                num_processes_per_machine=1,
                world_size=1,
            ),
            return_save_file=None,  # don't save file
        )

    return wrapper
def enable_ddp_env(backend="gloo"):
    """Build a decorator that runs a function through ``distributed_worker``.

    Intended usage is ``@enable_ddp_env()`` or ``@enable_ddp_env(backend=...)``.
    For backward compatibility the bare form ``@enable_ddp_env`` is accepted
    too: if the first argument is callable it is treated as the function to
    decorate and the default "gloo" backend is used.

    Args:
        backend: name of the distributed backend forwarded to
            ``distributed_worker`` (default "gloo").

    Returns:
        A decorator. When used in the bare form, the already-decorated
        function is returned instead.
    """
    if callable(backend):
        # Bare ``@enable_ddp_env``: the "backend" is really the test function.
        target, backend_name = backend, "gloo"
    else:
        target, backend_name = None, backend

    def find_free_port() -> str:
        """Return, as a string, a port number the OS reports as free."""
        # Close the probe socket explicitly instead of relying on garbage
        # collection, so no descriptor (or ResourceWarning) leaks per test.
        with socket.socket() as s:
            s.bind(("localhost", 0))  # Bind to a free port provided by the host.
            return str(s.getsockname()[1])

    def _enable_ddp_env(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # NOTE: these environment variables are left set after the call.
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = find_free_port()

            return distributed_worker(
                main_func=func,
                args=args,
                kwargs=kwargs,
                backend=backend_name,
                # A unique file name per invocation keeps runs from colliding.
                init_method="file:///tmp/detectron2go_test_ddp_init_{}".format(
                    uuid.uuid4().hex
                ),
                dist_params=DistributedParams(
                    local_rank=0,
                    machine_rank=0,
                    global_rank=0,
                    num_processes_per_machine=1,
                    world_size=1,
                ),
                return_save_file=None,  # don't save file
            )

        return wrapper

    if target is not None:
        return _enable_ddp_env(target)
    return _enable_ddp_env
def tempdir(func):
......
......@@ -27,7 +27,7 @@ class TestLightningTrainNet(unittest.TestCase):
return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
@tempdir
@enable_ddp_env
@enable_ddp_env()
def test_train_net_main(self, root_dir):
"""tests the main training entry point."""
cfg = self._get_cfg(root_dir)
......@@ -36,7 +36,7 @@ class TestLightningTrainNet(unittest.TestCase):
main(cfg, root_dir, GeneralizedRCNNTask)
@tempdir
@enable_ddp_env
@enable_ddp_env()
def test_checkpointing(self, tmp_dir):
"""tests saving and loading from checkpoint."""
cfg = self._get_cfg(tmp_dir)
......
......@@ -237,7 +237,7 @@ class TestOptimizer(unittest.TestCase):
self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
)
@helper.enable_ddp_env
@helper.enable_ddp_env()
def test_create_optimizer_custom_ddp(self):
class Model(torch.nn.Module):
def __init__(self):
......
......@@ -173,7 +173,7 @@ class TestDefaultRunner(unittest.TestCase):
)
self.assertTrue(isinstance(runner, default_runner.Detectron2GoRunner))
@helper.enable_ddp_env
@helper.enable_ddp_env()
def test_d2go_runner_ema(self):
with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment