Commit 94dc481a authored by Yanghan Wang, committed by Facebook GitHub Bot

unify DDP launcher for elastic and non-elastic (support elastic launch correctly)

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/274

X-link: https://github.com/facebookresearch/mobile-vision/pull/76

TLDR: this diff consolidates the `distributed_helper` of `mobile_cv`; together with `mobile_cv`'s `comm` module, it should be the go-to library for dealing with DDP. D2Go's `distributed` is now built on top of `mobile_cv`'s `distributed_helper`.
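
For illustration, a minimal usage sketch of the unified launcher (hypothetical example: `train_main`, `cfg`, and the `dist_url` value are placeholders; the keyword arguments follow the new `launch` signature shown in the diff below):

    from d2go.distributed import launch

    def train_main(cfg, output_dir):
        # runs inside every DDP worker; detectron2's comm utilities are usable here
        ...

    # `cfg` stands for a d2go CfgNode created elsewhere (placeholder)
    launch(
        train_main,
        num_processes_per_machine=2,
        num_machines=1,
        machine_rank=0,
        dist_url="tcp://127.0.0.1:12345",  # placeholder rendezvous URL
        backend="NCCL",
        args=(cfg, "./output"),
    )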

Reviewed By: newstzpz

Differential Revision: D36787336

fbshipit-source-id: 640c9dcff5eec534e7894c75cfdf0a12d21c297e
parent 1f45cf04
@@ -3,64 +3,73 @@
 """
-Similar to detectron2.engine.launch, may support a few more things:
-  - support for get_local_rank.
-  - support other backends like GLOO.
+Extend the mobile_cv.torch.utils_pytorch.distributed_helper to add D2/D2Go specific
+features; functions in this module share the same signatures as the ones from mobile_cv.
 """
 import logging
-import os
-import tempfile
+from typing import Any, Callable, Dict, Optional, Tuple
 
-import detectron2.utils.comm as comm
+import detectron2.utils.comm as d2_comm
+import mobile_cv.torch.utils_pytorch.comm as mcv_comm
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
 from d2go.config import CfgNode, temp_defrost
 from d2go.utils.launch_environment import get_launch_environment
+from mobile_cv.torch.utils_pytorch.distributed_helper import (
+    DistributedParams,
+    enable_dist_process_groups,
+    launch as _launch,
+    save_return_deco,
+)
 
 logger = logging.getLogger(__name__)
 
-_LOCAL_RANK = 0
-_NUM_PROCESSES_PER_MACHINE = 1
-
-
-def _set_local_rank(local_rank):
-    global _LOCAL_RANK
-    _LOCAL_RANK = local_rank
-
-
-def _set_num_processes_per_machine(num_processes):
-    global _NUM_PROCESSES_PER_MACHINE
-    _NUM_PROCESSES_PER_MACHINE = num_processes
-
 
+# BC-compatible
 def get_local_rank():
-    return _LOCAL_RANK
+    return mcv_comm.get_local_rank()
 
 
+# BC-compatible
 def get_num_processes_per_machine():
-    return _NUM_PROCESSES_PER_MACHINE
+    return mcv_comm.get_local_size()
 
 
-# TODO: merge with d2.engine.launch
+# Modify mobile_cv's `default_distributed_worker` to also setup D2's comm module
+def distributed_worker(
+    main_func: Callable,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    backend: str,
+    dist_url: Optional[str] = None,
+    dist_params: Optional[DistributedParams] = None,
+    return_save_file: Optional[str] = None,
+):
+    dist_params = dist_params or DistributedParams.from_environ()
+    with enable_dist_process_groups(backend, dist_url, dist_params):
+        d2_comm._LOCAL_PROCESS_GROUP = mcv_comm._LOCAL_PROCESS_GROUP
+        # Now the D2's comm module should be fully functional
+        deco = save_return_deco(return_save_file, dist_params.global_rank)
+        return deco(main_func)(*args, **kwargs)
+
+
 def launch(
-    main_func,
-    num_processes_per_machine,
-    num_machines=1,
-    machine_rank=0,
-    dist_url=None,
-    backend="NCCL",
-    always_spawn=False,
-    args=(),
+    main_func: Callable,
+    num_processes_per_machine: int,
+    num_machines: int = 1,
+    machine_rank: int = 0,
+    dist_url: Optional[str] = None,
+    backend: str = "NCCL",
+    always_spawn: bool = False,
+    launch_method: str = "multiprocessing",
+    args: Tuple[Any, ...] = (),
+    kwargs: Dict[str, Any] = None,
 ):
-    logger.info(
-        f"Launch with num_processes_per_machine: {num_processes_per_machine},"
-        f" num_machines: {num_machines}, machine_rank: {machine_rank},"
-        f" dist_url: {dist_url}, backend: {backend}."
-    )
+    """
+    D2Go's specialized launch method; it does a few more things on top of mcv's launch:
+    - Automatically convert GPU to CPU if CUDA is not available.
+    - Add D2Go-specific initialization in the _distributed_worker.
+    """
     if get_launch_environment() == "local" and not torch.cuda.is_available():
         assert len(args) > 0, args
@@ -74,144 +83,16 @@ def launch(
         cfg.MODEL.DEVICE = "cpu"
         backend = "GLOO"
 
-    if backend == "NCCL":
-        assert (
-            num_processes_per_machine <= torch.cuda.device_count()
-        ), "num_processes_per_machine is greater than device count: {} vs {}".format(
-            num_processes_per_machine, torch.cuda.device_count()
-        )
-
-    world_size = num_machines * num_processes_per_machine
-    if world_size > 1 or always_spawn:
-        # https://github.com/pytorch/pytorch/pull/14391
-        # TODO prctl in spawned processes
-        prefix = f"detectron2go_{main_func.__module__}.{main_func.__name__}_return"
-        with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".pth") as f:
-            return_file = f.name
-            if dist_url.startswith("env://"):
-                _run_with_dist_env(
-                    main_func,
-                    world_size,
-                    num_processes_per_machine,
-                    machine_rank,
-                    dist_url,
-                    backend,
-                    return_file,
-                    args,
-                )
-            else:
-                mp.spawn(
-                    _distributed_worker,
-                    nprocs=num_processes_per_machine,
-                    args=(
-                        main_func,
-                        world_size,
-                        num_processes_per_machine,
-                        machine_rank,
-                        dist_url,
-                        backend,
-                        return_file,
-                        args,
-                    ),
-                    daemon=False,
-                )
-            if machine_rank == 0 and get_local_rank() == 0:
-                return torch.load(return_file)
-    else:
-        return main_func(*args)
-
-
-def _run_with_dist_env(
-    main_func,
-    world_size,
-    num_processes_per_machine,
-    machine_rank,
-    dist_url,
-    backend,
-    return_file,
-    args,
-):
-    assert dist_url.startswith("env://")
-    # Read torch.distributed params from env according to the contract in
-    # https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
-    world_size = int(os.environ.get("WORLD_SIZE", 1))
-    num_processes_per_machine = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
-    machine_rank = int(os.environ.get("GROUP_RANK", 0))
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    num_machines = int(world_size / num_processes_per_machine)
-    logger.info(
-        "Loaded distributed params from env."
-        f" Run with num_processes_per_machine: {num_processes_per_machine},"
-        f" num_machines: {num_machines}, machine_rank: {machine_rank},"
-    )
-    _distributed_worker(
-        local_rank,
-        main_func,
-        world_size,
-        num_processes_per_machine,
-        machine_rank,
-        dist_url,
-        backend,
-        return_file,
-        args,
-    )
-
-
-def _distributed_worker(
-    local_rank,
-    main_func,
-    world_size,
-    num_processes_per_machine,
-    machine_rank,
-    dist_url,
-    backend,
-    return_file,
-    args,
-):
-    assert backend in ["NCCL", "GLOO"]
-    _set_local_rank(local_rank)
-    _set_num_processes_per_machine(num_processes_per_machine)
-    # NOTE: this is wrong if using different number of processes across machine
-    global_rank = machine_rank * num_processes_per_machine + local_rank
-    try:
-        dist.init_process_group(
-            backend=backend,
-            init_method=dist_url,
-            world_size=world_size,
-            rank=global_rank,
-        )
-    except Exception as e:
-        logger.error("Process group URL: {}".format(dist_url))
-        raise e
-
-    # Setup the local process group (which contains ranks within the same machine)
-    assert comm._LOCAL_PROCESS_GROUP is None
-    num_machines = world_size // num_processes_per_machine
-    for i in range(num_machines):
-        ranks_on_i = list(
-            range(i * num_processes_per_machine, (i + 1) * num_processes_per_machine)
-        )
-        pg = dist.new_group(ranks_on_i)
-        if i == machine_rank:
-            comm._LOCAL_PROCESS_GROUP = pg
-
-    if backend in ["NCCL"]:
-        torch.cuda.set_device(local_rank)
-
-    # synchronize is needed here to prevent a possible timeout after calling
-    # init_process_group
-    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
-    comm.synchronize()
-
-    ret = main_func(*args)
-    if global_rank == 0:
-        logger.info(
-            "Save {}.{} return to: {}".format(
-                main_func.__module__, main_func.__name__, return_file
-            )
-        )
-        torch.save(ret, return_file)
+    return _launch(
+        main_func=main_func,
+        num_processes_per_machine=num_processes_per_machine,
+        num_machines=num_machines,
+        machine_rank=machine_rank,
+        dist_url=dist_url,
+        backend=backend,
+        always_spawn=always_spawn,
+        launch_method=launch_method,
+        args=args,
+        kwargs=kwargs,
+        _distributed_worker=distributed_worker,
+    )
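
For reference, under an elastic launcher such as torchrun, each spawned process could presumably call `distributed_worker` directly and let `DistributedParams.from_environ()` pick up the standard environment variables (`WORLD_SIZE`, `LOCAL_WORLD_SIZE`, `GROUP_RANK`, `LOCAL_RANK`) that the removed `_run_with_dist_env` used to parse by hand. A rough sketch (hypothetical; `train_main` and `cfg` are placeholders defined by the caller):

    from d2go.distributed import distributed_worker

    distributed_worker(
        main_func=train_main,
        args=(cfg,),
        kwargs={},
        backend="NCCL",
        dist_url="env://",       # rendezvous info comes from the environment
        dist_params=None,        # falls back to DistributedParams.from_environ()
        return_save_file=None,   # don't persist the return value
    )
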
@@ -196,7 +196,6 @@ def setup_after_launch(
     cfg: CfgNode,
     output_dir: str,
     runner: Optional[BaseRunner] = None,
-    _scale_world_size: bool = True,  # HACK: temporarily allow lightning_train_net to by pass this.
 ):
     """
     Binary-level setup after entering DDP, including
@@ -233,8 +232,7 @@ def setup_after_launch(
         pass
 
     # scale the config after dumping so that dumped config files keep original world size
-    if _scale_world_size:
-        auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
+    auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
 
 
 def setup_logger(
@@ -11,7 +11,7 @@ from tempfile import TemporaryDirectory
 from typing import Optional
 
 import torch
-import torch.distributed as dist
+from d2go.distributed import distributed_worker, DistributedParams
 
 
 def get_resource_path(file: Optional[str] = None):
@@ -51,17 +51,24 @@ def enable_ddp_env(func):
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = find_free_port()
 
-        dist.init_process_group(
-            "gloo",
-            rank=0,
-            world_size=1,
-            init_method="file:///tmp/detectron2go_test_ddp_init_{}".format(
+        return distributed_worker(
+            main_func=func,
+            args=args,
+            kwargs=kwargs,
+            backend="gloo",
+            dist_url="file:///tmp/detectron2go_test_ddp_init_{}".format(
                 uuid.uuid4().hex
             ),
+            dist_params=DistributedParams(
+                local_rank=0,
+                machine_rank=0,
+                global_rank=0,
+                num_processes_per_machine=1,
+                world_size=1,
+            ),
+            return_save_file=None,  # don't save file
         )
-        ret = func(*args, **kwargs)
-        dist.destroy_process_group()
-        return ret
 
     return wrapper
@@ -5,13 +5,12 @@ import os
 import unittest
 
 import numpy as np
-import torch.distributed as dist
 from d2go.config import CfgNode
 from d2go.config.utils import flatten_config_dict
 from d2go.runner.lightning_task import GeneralizedRCNNTask
 from d2go.tools.lightning_train_net import FINAL_MODEL_CKPT, main
 from d2go.utils.testing import meta_arch_helper as mah
-from d2go.utils.testing.helper import tempdir
+from d2go.utils.testing.helper import enable_ddp_env, tempdir
 
 
 class TestLightningTrainNet(unittest.TestCase):
@@ -28,6 +27,7 @@ class TestLightningTrainNet(unittest.TestCase):
         return mah.create_detection_cfg(GeneralizedRCNNTask, tmp_dir)
 
     @tempdir
+    @enable_ddp_env
     def test_train_net_main(self, root_dir):
         """tests the main training entry point."""
         cfg = self._get_cfg(root_dir)
@@ -36,6 +36,7 @@ class TestLightningTrainNet(unittest.TestCase):
         main(cfg, root_dir)
 
     @tempdir
+    @enable_ddp_env
     def test_checkpointing(self, tmp_dir):
         """tests saving and loading from checkpoint."""
         cfg = self._get_cfg(tmp_dir)
@@ -57,7 +58,3 @@ class TestLightningTrainNet(unittest.TestCase):
         accuracy2 = flatten_config_dict(out2.accuracy)
         for k in accuracy:
             np.testing.assert_equal(accuracy[k], accuracy2[k])
-
-    def tearDown(self):
-        if dist.is_initialized():
-            dist.destroy_process_group()
@@ -7,8 +7,9 @@ import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Type
 
+import mobile_cv.torch.utils_pytorch.comm as comm
 import pytorch_lightning as pl  # type: ignore
-from d2go.config import auto_scale_world_size, CfgNode, temp_defrost
+from d2go.config import CfgNode, temp_defrost
 from d2go.runner import create_runner
 from d2go.runner.callbacks.quantization import QuantizationAwareTraining
 from d2go.runner.lightning_task import GeneralizedRCNNTask
@@ -67,9 +68,7 @@ def _get_accelerator(use_cpu: bool) -> str:
     return "cpu" if use_cpu else "gpu"
 
 
-def get_trainer_params(
-    cfg: CfgNode, num_machines: int, num_processes: int
-) -> Dict[str, Any]:
+def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
     strategy = _get_strategy(cfg)
     accelerator = _get_accelerator(use_cpu)
@@ -80,8 +79,8 @@ def get_trainer_params(
         "val_check_interval": cfg.TEST.EVAL_PERIOD
         if cfg.TEST.EVAL_PERIOD > 0
         else cfg.SOLVER.MAX_ITER,
-        "num_nodes": num_machines,
-        "devices": num_processes,
+        "num_nodes": comm.get_num_nodes(),
+        "devices": comm.get_local_size(),
         "strategy": strategy,
         "accelerator": accelerator,
         "callbacks": _get_trainer_callbacks(cfg),
@@ -138,8 +137,6 @@ def main(
     output_dir: str,
     task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
     eval_only: bool = False,
-    num_machines: int = 1,
-    num_processes: int = 1,
 ) -> TrainOutput:
     """Main function for launching a training with lightning trainer
     Args:
@@ -148,12 +145,10 @@ def main(
         num_processes: Number of processes on each node.
         eval_only: True if run evaluation only.
     """
-    # FIXME: make comm.get_world_size() work properly.
-    setup_after_launch(cfg, output_dir, _scale_world_size=False)
-    auto_scale_world_size(cfg, new_world_size=num_machines * num_processes)
+    setup_after_launch(cfg, output_dir)
 
     task = task_cls.from_config(cfg, eval_only)
-    trainer_params = get_trainer_params(cfg, num_machines, num_processes)
+    trainer_params = get_trainer_params(cfg)
 
     last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
     if PathManager.exists(last_checkpoint):
@@ -212,8 +207,6 @@ if __name__ == "__main__":
         args.output_dir,
         task_cls,
         eval_only=False,  # eval_only
-        num_machines=args.num_machines,
-        num_processes=args.num_processes,
     )
     if get_rank() == 0:
         print(ret)