Commit 94dc481a authored by Yanghan Wang, committed by Facebook GitHub Bot

unify DDP launcher for elastic and non-elastic (support elastic launch correctly)

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/274

X-link: https://github.com/facebookresearch/mobile-vision/pull/76

TLDR: this diff consolidates the `distributed_helper` of `mobile_cv`; together with `mobile_cv`'s `comm` module, it should be the go-to library for dealing with DDP. D2Go's `distributed` is now built on top of `mobile_cv`'s `distributed_helper`.
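
For reference, a minimal usage sketch of the unified launcher after this change. The argument values below are placeholders (not taken from this diff); the signature follows the new `d2go.distributed.launch`, and for elastic runs the distributed parameters are instead read from the environment via `DistributedParams.from_environ()` inside the worker.

```python
# Hypothetical caller sketch; all values are placeholders.
from d2go.distributed import launch


def main_func(message: str) -> str:
    # real training / evaluation code would go here
    return "worker says: " + message


if __name__ == "__main__":
    launch(
        main_func,
        num_processes_per_machine=2,
        num_machines=1,
        machine_rank=0,
        dist_url="file:///tmp/example_ddp_init",  # placeholder init method
        backend="GLOO",  # e.g. "NCCL" when CUDA is available
        launch_method="multiprocessing",  # new argument; default shown explicitly
        args=("d2go",),
    )
```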

Reviewed By: newstzpz

Differential Revision: D36787336

fbshipit-source-id: 640c9dcff5eec534e7894c75cfdf0a12d21c297e
parent 1f45cf04
......@@ -3,64 +3,73 @@
"""
Similar to detectron2.engine.launch, may support a few more things:
- support for get_local_rank.
- support other backends like GLOO.
Extend mobile_cv.torch.utils_pytorch.distributed_helper with D2/D2Go-specific
features; functions in this module share the same signatures as their mobile_cv counterparts.
"""
import logging
import os
import tempfile
from typing import Any, Callable, Dict, Optional, Tuple
import detectron2.utils.comm as comm
import detectron2.utils.comm as d2_comm
import mobile_cv.torch.utils_pytorch.comm as mcv_comm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from d2go.config import CfgNode, temp_defrost
from d2go.utils.launch_environment import get_launch_environment
from mobile_cv.torch.utils_pytorch.distributed_helper import (
DistributedParams,
enable_dist_process_groups,
launch as _launch,
save_return_deco,
)
logger = logging.getLogger(__name__)
_LOCAL_RANK = 0
_NUM_PROCESSES_PER_MACHINE = 1
def _set_local_rank(local_rank):
global _LOCAL_RANK
_LOCAL_RANK = local_rank
def _set_num_processes_per_machine(num_processes):
global _NUM_PROCESSES_PER_MACHINE
_NUM_PROCESSES_PER_MACHINE = num_processes
# BC-compatible
def get_local_rank():
return _LOCAL_RANK
return mcv_comm.get_local_rank()
# BC-compatible
def get_num_processes_per_machine():
return _NUM_PROCESSES_PER_MACHINE
return mcv_comm.get_local_size()
# Modify mobile_cv's `default_distributed_worker` to also set up D2's comm module
def distributed_worker(
main_func: Callable,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
backend: str,
dist_url: Optional[str] = None,
dist_params: Optional[DistributedParams] = None,
return_save_file: Optional[str] = None,
):
dist_params = dist_params or DistributedParams.from_environ()
with enable_dist_process_groups(backend, dist_url, dist_params):
d2_comm._LOCAL_PROCESS_GROUP = mcv_comm._LOCAL_PROCESS_GROUP
# Now D2's comm module should be fully functional
deco = save_return_deco(return_save_file, dist_params.global_rank)
return deco(main_func)(*args, **kwargs)
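
Not part of this diff — a minimal sketch of driving this worker directly for a single process, in the same spirit as the updated `enable_ddp_env` test helper further down; the function and values are placeholders.

```python
# Illustrative single-process invocation of distributed_worker (placeholder values).
import uuid

from d2go.distributed import distributed_worker, DistributedParams


def my_main(x: int, y: int = 1) -> int:
    return x + y


result = distributed_worker(
    main_func=my_main,
    args=(1,),
    kwargs={"y": 2},
    backend="gloo",
    dist_url="file:///tmp/example_ddp_init_{}".format(uuid.uuid4().hex),
    dist_params=DistributedParams(
        local_rank=0,
        machine_rank=0,
        global_rank=0,
        num_processes_per_machine=1,
        world_size=1,
    ),
    return_save_file=None,  # don't persist the return value
)
assert result == 3
```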
# TODO: merge with d2.engine.launch
def launch(
main_func,
num_processes_per_machine,
num_machines=1,
machine_rank=0,
dist_url=None,
backend="NCCL",
always_spawn=False,
args=(),
main_func: Callable,
num_processes_per_machine: int,
num_machines: int = 1,
machine_rank: int = 0,
dist_url: Optional[str] = None,
backend: str = "NCCL",
always_spawn: bool = False,
launch_method: str = "multiprocessing",
args: Tuple[Any, ...] = (),
kwargs: Optional[Dict[str, Any]] = None,
):
logger.info(
f"Launch with num_processes_per_machine: {num_processes_per_machine},"
f" num_machines: {num_machines}, machine_rank: {machine_rank},"
f" dist_url: {dist_url}, backend: {backend}."
)
"""
D2Go's specialized launch method; it does a few more things on top of mcv's launch:
- Automatically fall back from GPU to CPU if CUDA is not available.
- Add D2Go-specific initialization in the distributed worker.
"""
if get_launch_environment() == "local" and not torch.cuda.is_available():
assert len(args) > 0, args
......@@ -74,144 +83,16 @@ def launch(
cfg.MODEL.DEVICE = "cpu"
backend = "GLOO"
if backend == "NCCL":
assert (
num_processes_per_machine <= torch.cuda.device_count()
), "num_processes_per_machine is greater than device count: {} vs {}".format(
num_processes_per_machine, torch.cuda.device_count()
)
world_size = num_machines * num_processes_per_machine
if world_size > 1 or always_spawn:
# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes
prefix = f"detectron2go_{main_func.__module__}.{main_func.__name__}_return"
with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".pth") as f:
return_file = f.name
if dist_url.startswith("env://"):
_run_with_dist_env(
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
)
else:
mp.spawn(
_distributed_worker,
nprocs=num_processes_per_machine,
args=(
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
),
daemon=False,
)
if machine_rank == 0 and get_local_rank() == 0:
return torch.load(return_file)
else:
return main_func(*args)
def _run_with_dist_env(
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
):
assert dist_url.startswith("env://")
# Read torch.distributed params from env according to the contract in
# https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
world_size = int(os.environ.get("WORLD_SIZE", 1))
num_processes_per_machine = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
machine_rank = int(os.environ.get("GROUP_RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
num_machines = int(world_size / num_processes_per_machine)
logger.info(
"Loaded distributed params from env."
f" Run with num_processes_per_machine: {num_processes_per_machine},"
f" num_machines: {num_machines}, machine_rank: {machine_rank},"
return _launch(
main_func=main_func,
num_processes_per_machine=num_processes_per_machine,
num_machines=num_machines,
machine_rank=machine_rank,
dist_url=dist_url,
backend=backend,
always_spawn=always_spawn,
launch_method=launch_method,
args=args,
kwargs=kwargs,
_distributed_worker=distributed_worker,
)
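
Not part of this diff — a rough sketch of the CPU fallback mentioned in the docstring above. Assuming a local environment (`get_launch_environment() == "local"`) without CUDA, a GPU-oriented config passed as the first positional arg is expected to be downgraded before the worker runs; the config construction and values below are illustrative.

```python
# Illustrative sketch of the CUDA-unavailable fallback (placeholder config/values).
from d2go.config import CfgNode
from d2go.distributed import launch


def report_device(cfg: CfgNode) -> None:
    # On a CUDA-less local machine, launch() should have rewritten
    # cfg.MODEL.DEVICE to "cpu" and switched the backend to GLOO by now.
    print("running with MODEL.DEVICE =", cfg.MODEL.DEVICE)


if __name__ == "__main__":
    cfg = CfgNode()
    cfg.MODEL = CfgNode()
    cfg.MODEL.DEVICE = "cuda"
    launch(
        report_device,
        num_processes_per_machine=1,  # single local process; no spawning needed
        backend="NCCL",  # downgraded to GLOO when CUDA is absent
        args=(cfg,),
    )
```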
_distributed_worker(
local_rank,
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
)
def _distributed_worker(
local_rank,
main_func,
world_size,
num_processes_per_machine,
machine_rank,
dist_url,
backend,
return_file,
args,
):
assert backend in ["NCCL", "GLOO"]
_set_local_rank(local_rank)
_set_num_processes_per_machine(num_processes_per_machine)
# NOTE: this is wrong if using different numbers of processes across machines
global_rank = machine_rank * num_processes_per_machine + local_rank
try:
dist.init_process_group(
backend=backend,
init_method=dist_url,
world_size=world_size,
rank=global_rank,
)
except Exception as e:
logger.error("Process group URL: {}".format(dist_url))
raise e
# Setup the local process group (which contains ranks within the same machine)
assert comm._LOCAL_PROCESS_GROUP is None
num_machines = world_size // num_processes_per_machine
for i in range(num_machines):
ranks_on_i = list(
range(i * num_processes_per_machine, (i + 1) * num_processes_per_machine)
)
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
comm._LOCAL_PROCESS_GROUP = pg
if backend in ["NCCL"]:
torch.cuda.set_device(local_rank)
# synchronize is needed here to prevent a possible timeout after calling
# init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
ret = main_func(*args)
if global_rank == 0:
logger.info(
"Save {}.{} return to: {}".format(
main_func.__module__, main_func.__name__, return_file
)
)
torch.save(ret, return_file)
......@@ -196,7 +196,6 @@ def setup_after_launch(
cfg: CfgNode,
output_dir: str,
runner: Optional[BaseRunner] = None,
_scale_world_size: bool = True, # HACK: temporarily allow lightning_train_net to bypass this.
):
"""
Binary-level setup after entering DDP, including
......@@ -233,8 +232,7 @@ def setup_after_launch(
pass
# scale the config after dumping so that dumped config files keep original world size
if _scale_world_size:
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
def setup_logger(
......
......@@ -11,7 +11,7 @@ from tempfile import TemporaryDirectory
from typing import Optional
import torch
import torch.distributed as dist
from d2go.distributed import distributed_worker, DistributedParams
def get_resource_path(file: Optional[str] = None):
......@@ -51,17 +51,24 @@ def enable_ddp_env(func):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = find_free_port()
dist.init_process_group(
"gloo",
rank=0,
world_size=1,
init_method="file:///tmp/detectron2go_test_ddp_init_{}".format(
return distributed_worker(
main_func=func,
args=args,
kwargs=kwargs,
backend="gloo",
dist_url="file:///tmp/detectron2go_test_ddp_init_{}".format(
uuid.uuid4().hex
),
dist_params=DistributedParams(
local_rank=0,
machine_rank=0,
global_rank=0,
num_processes_per_machine=1,
world_size=1,
),
return_save_file=None, # don't save file
)
ret = func(*args, **kwargs)
dist.destroy_process_group()
return ret
return wrapper
......
......@@ -5,13 +5,12 @@ import os
import unittest
import numpy as np
import torch.distributed as dist
from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tools.lightning_train_net import FINAL_MODEL_CKPT, main
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
from d2go.utils.testing.helper import enable_ddp_env, tempdir
class TestLightningTrainNet(unittest.TestCase):
......@@ -28,6 +27,7 @@ class TestLightningTrainNet(unittest.TestCase):
return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
@tempdir
@enable_ddp_env
def test_train_net_main(self, root_dir):
"""tests the main training entry point."""
cfg = self._get_cfg(root_dir)
......@@ -36,6 +36,7 @@ class TestLightningTrainNet(unittest.TestCase):
main(cfg, root_dir)
@tempdir
@enable_ddp_env
def test_checkpointing(self, tmp_dir):
"""tests saving and loading from checkpoint."""
cfg = self._get_cfg(tmp_dir)
......@@ -57,7 +58,3 @@ class TestLightningTrainNet(unittest.TestCase):
accuracy2 = flatten_config_dict(out2.accuracy)
for k in accuracy:
np.testing.assert_equal(accuracy[k], accuracy2[k])
def tearDown(self):
if dist.is_initialized():
dist.destroy_process_group()
......@@ -7,8 +7,9 @@ import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type
import mobile_cv.torch.utils_pytorch.comm as comm
import pytorch_lightning as pl # type: ignore
from d2go.config import auto_scale_world_size, CfgNode, temp_defrost
from d2go.config import CfgNode, temp_defrost
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import GeneralizedRCNNTask
......@@ -67,9 +68,7 @@ def _get_accelerator(use_cpu: bool) -> str:
return "cpu" if use_cpu else "gpu"
def get_trainer_params(
cfg: CfgNode, num_machines: int, num_processes: int
) -> Dict[str, Any]:
def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
strategy = _get_strategy(cfg)
accelerator = _get_accelerator(use_cpu)
......@@ -80,8 +79,8 @@ def get_trainer_params(
"val_check_interval": cfg.TEST.EVAL_PERIOD
if cfg.TEST.EVAL_PERIOD > 0
else cfg.SOLVER.MAX_ITER,
"num_nodes": num_machines,
"devices": num_processes,
"num_nodes": comm.get_num_nodes(),
"devices": comm.get_local_size(),
"strategy": strategy,
"accelerator": accelerator,
"callbacks": _get_trainer_callbacks(cfg),
......@@ -138,8 +137,6 @@ def main(
output_dir: str,
task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
eval_only: bool = False,
num_machines: int = 1,
num_processes: int = 1,
) -> TrainOutput:
"""Main function for launching a training with lightning trainer
Args:
......@@ -148,12 +145,10 @@ def main(
num_processes: Number of processes on each node.
eval_only: True if run evaluation only.
"""
# FIXME: make comm.get_world_size() work properly.
setup_after_launch(cfg, output_dir, _scale_world_size=False)
auto_scale_world_size(cfg, new_world_size=num_machines * num_processes)
setup_after_launch(cfg, output_dir)
task = task_cls.from_config(cfg, eval_only)
trainer_params = get_trainer_params(cfg, num_machines, num_processes)
trainer_params = get_trainer_params(cfg)
last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
if PathManager.exists(last_checkpoint):
......@@ -212,8 +207,6 @@ if __name__ == "__main__":
args.output_dir,
task_cls,
eval_only=False, # eval_only
num_machines=args.num_machines,
num_processes=args.num_processes,
)
if get_rank() == 0:
print(ret)