Commit d8bdc633 authored by Tsahi Glik, committed by Facebook GitHub Bot

Integrate AIEnv with D2Go train_net

Summary:
Add support in d2go.distributed for the `env://` init method. Use the environment variables specified in https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization to initialize the distributed params.
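
For reference, the environment-variable contract boils down to the following (a minimal sketch; only the variable names come from the PyTorch docs linked above, the values are illustrative):

    import os
    import torch.distributed as dist

    # A launcher is expected to export these before the worker starts:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # host of the rank-0 process
    os.environ.setdefault("MASTER_PORT", "29500")      # free port on that host
    os.environ.setdefault("WORLD_SIZE", "1")           # total number of processes
    os.environ.setdefault("RANK", "0")                 # global rank of this process

    # With those set, no explicit URL or store needs to be passed:
    dist.init_process_group(backend="gloo", init_method="env://")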

Also change the train_net `cli` function signature to accept an args list instead of always reading `sys.argv`, so that the function can be called directly from the AIEnv launcher.
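
A host process could then invoke training in-process, for example (a sketch; `--config-file` is assumed to come from `basic_argument_parser`, and the list follows argparse's convention of excluding the program name):

    from tools.train_net import cli

    cli(["--config-file", "configs/my_config.yaml", "--eval-only"])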

Differential Revision: D34540275

fbshipit-source-id: 7f718aed4c010b0ac8347d43b5ca5b401210756c
parent f16cc060
d2go/distributed.py
@@ -9,6 +9,7 @@ Similar to detectron2.engine.launch, may support a few more things:
 """
 import logging
+import os
 import tempfile
 
 import detectron2.utils.comm as comm
@@ -87,10 +88,8 @@ def launch(
     prefix = f"detectron2go_{main_func.__module__}.{main_func.__name__}_return"
     with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".pth") as f:
         return_file = f.name
-        mp.spawn(
-            _distributed_worker,
-            nprocs=num_processes_per_machine,
-            args=(
+        if dist_url.startswith("env://"):
+            _run_with_dist_env(
                 main_func,
                 world_size,
                 num_processes_per_machine,
@@ -99,15 +98,69 @@ def launch(
                 backend,
                 return_file,
                 args,
-            ),
-            daemon=False,
-        )
-    if machine_rank == 0:
+            )
+        else:
+            mp.spawn(
+                _distributed_worker,
+                nprocs=num_processes_per_machine,
+                args=(
+                    main_func,
+                    world_size,
+                    num_processes_per_machine,
+                    machine_rank,
+                    dist_url,
+                    backend,
+                    return_file,
+                    args,
+                ),
+                daemon=False,
+            )
+    if machine_rank == 0 and get_local_rank() == 0:
         return torch.load(return_file)
     else:
         return main_func(*args)
 
 
+def _run_with_dist_env(
+    main_func,
+    world_size,
+    num_processes_per_machine,
+    machine_rank,
+    dist_url,
+    backend,
+    return_file,
+    args,
+):
+    assert dist_url.startswith("env://")
+    # Read torch.distributed params from env according to the contract in
+    # https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    num_processes_per_machine = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+    machine_rank = int(os.environ.get("GROUP_RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    num_machines = int(world_size / num_processes_per_machine)
+    logger.info(
+        "Loaded distributed params from env."
+        f" Run with num_processes_per_machine: {num_processes_per_machine},"
+        f" num_machines: {num_machines}, machine_rank: {machine_rank},"
+    )
+    _distributed_worker(
+        local_rank,
+        main_func,
+        world_size,
+        num_processes_per_machine,
+        machine_rank,
+        dist_url,
+        backend,
+        return_file,
+        args,
+    )
+
+
 def _distributed_worker(
     local_rank,
     main_func,
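
Under a standard env:// launcher such as torchrun, the new branch means the caller-supplied values are superseded by whatever the environment says; a hedged sketch of the resulting flow (the keyword names for `launch` are inferred from the diff above, and `main` is a stand-in):

    # e.g. started as: torchrun --nnodes=2 --nproc_per_node=8 my_script.py
    # torchrun exports WORLD_SIZE=16, LOCAL_WORLD_SIZE=8, GROUP_RANK, LOCAL_RANK, ...
    from d2go.distributed import launch

    def main(*args):
        ...

    launch(
        main,
        num_processes_per_machine=8,  # overridden by LOCAL_WORLD_SIZE
        num_machines=2,               # recomputed as WORLD_SIZE / LOCAL_WORLD_SIZE
        machine_rank=0,               # overridden by GROUP_RANK
        dist_url="env://",            # selects the _run_with_dist_env path
        args=(),
    )

Note that the env:// branch calls `_distributed_worker` directly in the current process rather than going through `mp.spawn`, since an external launcher has already created one process per rank.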
tools/train_net.py
@@ -6,6 +6,7 @@ Detection Training Script.
 """
 import logging
+import sys
 
 import detectron2.utils.comm as comm
from d2go.distributed import launch
@@ -89,7 +90,7 @@ def run_with_cmdline_args(args):
     )
 
 
-def cli():
+def cli(args):
     parser = basic_argument_parser(requires_output_dir=False)
     parser.add_argument(
         "--eval-only", action="store_true", help="perform evaluation only"
@@ -99,8 +100,8 @@ def cli():
         action="store_true",
         help="whether to attempt to resume from the checkpoint directory",
     )
-    run_with_cmdline_args(parser.parse_args())
+    run_with_cmdline_args(parser.parse_args(args))
 
 
 if __name__ == "__main__":
-    cli()
+    cli(sys.argv)
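
Putting the two halves together, an in-process launcher can export the distributed parameters and then call the CLI directly; a hedged end-to-end sketch (the values describe process 3 on machine 1 of a 2x8 job, the config path is made up, and the run is assumed to be configured with dist_url="env://" so that _run_with_dist_env picks the values back up):

    import os
    from tools.train_net import cli

    os.environ.update(
        {
            "MASTER_ADDR": "host0",
            "MASTER_PORT": "29500",
            "WORLD_SIZE": "16",       # 2 machines x 8 processes
            "LOCAL_WORLD_SIZE": "8",  # processes on this machine
            "GROUP_RANK": "1",        # this machine's rank
            "LOCAL_RANK": "3",        # this process's rank on this machine
            "RANK": "11",             # 1 * 8 + 3, the global rank
        }
    )

    cli(["--config-file", "configs/my_config.yaml"])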