You need to sign in or sign up before continuing.
Commit d8bdc633 authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot
Browse files

Integrate AIEnv with D2Go train_net

Summary:
Add support in d2go.distributed for the `env://` init method. Use environment variables, as specified in https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization, to initialize the distributed params.

Also change the train_net cli function signature to accept an args list instead of only using `sys.argv`, to allow calling this function from the AIEnv launcher.

Differential Revision: D34540275

fbshipit-source-id: 7f718aed4c010b0ac8347d43b5ca5b401210756c
parent f16cc060
...@@ -9,6 +9,7 @@ Similar to detectron2.engine.launch, may support a few more things: ...@@ -9,6 +9,7 @@ Similar to detectron2.engine.launch, may support a few more things:
""" """
import logging import logging
import os
import tempfile import tempfile
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
...@@ -87,10 +88,8 @@ def launch( ...@@ -87,10 +88,8 @@ def launch(
prefix = f"detectron2go_{main_func.__module__}.{main_func.__name__}_return" prefix = f"detectron2go_{main_func.__module__}.{main_func.__name__}_return"
with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".pth") as f: with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".pth") as f:
return_file = f.name return_file = f.name
mp.spawn( if dist_url.startswith("env://"):
_distributed_worker, _run_with_dist_env(
nprocs=num_processes_per_machine,
args=(
main_func, main_func,
world_size, world_size,
num_processes_per_machine, num_processes_per_machine,
...@@ -99,15 +98,69 @@ def launch( ...@@ -99,15 +98,69 @@ def launch(
backend, backend,
return_file, return_file,
args, args,
), )
daemon=False, else:
) mp.spawn(
if machine_rank == 0: _distributed_worker,
nprocs=num_processes_per_machine,
args=(
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
),
daemon=False,
)
if machine_rank == 0 and get_local_rank() == 0:
return torch.load(return_file) return torch.load(return_file)
else: else:
return main_func(*args) return main_func(*args)
def _run_with_dist_env(
    main_func,
    world_size,
    num_processes_per_machine,
    machine_rank,
    dist_url,
    backend,
    return_file,
    args,
):
    """Run a single distributed worker with params taken from env variables.

    Only valid for the ``env://`` init method.  Note that the passed-in
    ``world_size``, ``num_processes_per_machine`` and ``machine_rank``
    arguments are deliberately overwritten below: with ``env://`` the
    external launcher's environment variables are the source of truth
    (they are kept in the signature so the call site can stay parallel
    to the ``mp.spawn`` path).
    """
    assert dist_url.startswith("env://")
    # Read torch.distributed params from env according to the contract in
    # https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    num_processes_per_machine = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    machine_rank = int(os.environ.get("GROUP_RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # Use integer floor division rather than int(a / b): float true-division
    # can lose precision for large values, and // states the intent directly.
    # WORLD_SIZE is presumably a multiple of LOCAL_WORLD_SIZE per the env://
    # launch contract — TODO confirm with the launcher.
    num_machines = world_size // num_processes_per_machine
    logger.info(
        "Loaded distributed params from env."
        f" Run with num_processes_per_machine: {num_processes_per_machine},"
        f" num_machines: {num_machines}, machine_rank: {machine_rank},"
    )
    # This process hosts exactly one worker (LOCAL_RANK), so call the worker
    # entry point directly instead of spawning subprocesses.
    _distributed_worker(
        local_rank,
        main_func,
        world_size,
        num_processes_per_machine,
        machine_rank,
        dist_url,
        backend,
        return_file,
        args,
    )
def _distributed_worker( def _distributed_worker(
local_rank, local_rank,
main_func, main_func,
......
...@@ -6,6 +6,7 @@ Detection Training Script. ...@@ -6,6 +6,7 @@ Detection Training Script.
""" """
import logging import logging
import sys
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
from d2go.distributed import launch from d2go.distributed import launch
...@@ -89,7 +90,7 @@ def run_with_cmdline_args(args): ...@@ -89,7 +90,7 @@ def run_with_cmdline_args(args):
) )
def cli(): def cli(args):
parser = basic_argument_parser(requires_output_dir=False) parser = basic_argument_parser(requires_output_dir=False)
parser.add_argument( parser.add_argument(
"--eval-only", action="store_true", help="perform evaluation only" "--eval-only", action="store_true", help="perform evaluation only"
...@@ -99,8 +100,8 @@ def cli(): ...@@ -99,8 +100,8 @@ def cli():
action="store_true", action="store_true",
help="whether to attempt to resume from the checkpoint directory", help="whether to attempt to resume from the checkpoint directory",
) )
run_with_cmdline_args(parser.parse_args()) run_with_cmdline_args(parser.parse_args(args))
if __name__ == "__main__": if __name__ == "__main__":
cli() cli(sys.argv)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment