Commit 69bf820c authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot

Add shared workers context API

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/116

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/398

D2Go doesn't have a per-node initialization API, only per-worker initialization that runs in each subprocess.
Some projects (like IOBT) need a way to do shared initialization before spawning the workers in subprocesses and to pass the resulting shared context to those workers.
This diff adds an API to create a shared context object before launching the workers; the runners inside the workers can then use that shared context after launch.
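For illustration, a minimal sketch of the intended flow. The runner class, payload, and `do_work` helper here are hypothetical; `D2GoSharedContext`, `get_shared_context`, `setup_before_launch`, and `launch` are the d2go entry points touched by this diff:

```python
from d2go.distributed import D2GoSharedContext, get_shared_context
from d2go.runner import BaseRunner


class MyRunner(BaseRunner):  # hypothetical runner
    @classmethod
    def create_shared_context(cls, cfg) -> D2GoSharedContext:
        # Runs once in the parent process, before the workers are spawned.
        expensive_resource = {"asset": "loaded-once"}  # hypothetical payload
        return D2GoSharedContext(runner_shared_context=expensive_resource)


def main(cfg, output_dir, runner_name):
    # Runs inside each spawned worker. distributed_worker() has already
    # called set_shared_context() with the object forwarded by mp spawn,
    # so the context is available through the global accessor.
    shared_context = get_shared_context()
    do_work(shared_context.runner_shared_context)  # hypothetical


# In the launcher process (mirroring the tools/ changes below):
# shared_context = setup_before_launch(cfg, output_dir, runner_name)
# launch(main, ..., shared_context=shared_context,
#        args=(cfg, output_dir, runner_name))
```

Since the context object is forwarded through the spawn arguments, anything stored in `runner_shared_context` must be picklable.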

Reviewed By: wat3rBro

Differential Revision: D40001329

fbshipit-source-id: 231a4e7e4da7b5db50849176c58b104c4565306a
parent e5fece78
@@ -8,6 +8,7 @@ features, functions in this module share the same signatures as the ones from mo
"""
import logging
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Callable, Dict, Optional, Tuple
@@ -16,6 +17,11 @@ import mobile_cv.torch.utils_pytorch.comm as mcv_comm
import torch
from d2go.config import CfgNode, temp_defrost
from d2go.utils.launch_environment import get_launch_environment
from mobile_cv.torch.utils_pytorch.comm import ( # noqa
BaseSharedContext,
get_shared_context,
set_shared_context,
)
from mobile_cv.torch.utils_pytorch.distributed_helper import (
DEFAULT_TIMEOUT,
DistributedParams,
@@ -24,9 +30,20 @@ from mobile_cv.torch.utils_pytorch.distributed_helper import (
save_return_deco,
)
logger = logging.getLogger(__name__)
@dataclass
class D2GoSharedContext(BaseSharedContext):
"""
Shared context that can be initialized before launching the workers
and passed to all workers.
"""
runner_shared_context: Any
# BC-compatible
def get_local_rank():
return mcv_comm.get_local_rank()
@@ -47,7 +64,12 @@ def distributed_worker(
dist_params: Optional[DistributedParams] = None,
return_save_file: Optional[str] = None,
timeout: timedelta = DEFAULT_TIMEOUT,
shared_context: Optional[BaseSharedContext] = None,
):
if shared_context:
set_shared_context(
shared_context
) # set the global shared context from the args passed in by mp spawn
dist_params = dist_params or DistributedParams.from_environ()
with enable_dist_process_groups(backend, init_method, dist_params, timeout):
d2_comm._LOCAL_PROCESS_GROUP = mcv_comm._LOCAL_PROCESS_GROUP
@@ -65,6 +87,7 @@ def launch(
backend: str = "NCCL",
always_spawn: bool = False,
launch_method: str = "multiprocessing",
shared_context: Optional[D2GoSharedContext] = None,
timeout: timedelta = DEFAULT_TIMEOUT,
args: Tuple[Any, ...] = (),
kwargs: Dict[str, Any] = None,
@@ -96,6 +119,7 @@
backend=backend,
always_spawn=always_spawn,
launch_method=launch_method,
shared_context=shared_context,
timeout=timeout,
args=args,
kwargs=kwargs,
......
@@ -22,6 +22,7 @@ from d2go.data.utils import (
maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset,
)
from d2go.distributed import D2GoSharedContext
from d2go.evaluation.evaluator import inference_on_dataset
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.api import build_d2go_model
@@ -60,7 +61,6 @@ from mobile_cv.common.misc.oss_utils import fb_overwritable
from mobile_cv.predictor.api import PredictorWrapper
from torch import nn
logger = logging.getLogger(__name__)
@@ -159,6 +159,13 @@ class BaseRunner(object):
"""
pass
@classmethod
def create_shared_context(cls, cfg) -> D2GoSharedContext:
"""
Override `create_shared_context` to run customized code that creates a distributed shared context accessible to all workers.
"""
pass
@classmethod
def get_default_cfg(cls):
return get_base_runner_default_cfg(CfgNode())
......
@@ -18,7 +18,11 @@ from d2go.config import (
temp_defrost,
)
from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.distributed import (
D2GoSharedContext,
get_local_rank,
get_num_processes_per_machine,
)
from d2go.runner import BaseRunner, DefaultTask, import_runner, RunnerV2Mixin
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment
@@ -202,6 +206,24 @@ def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
cfg.OUTPUT_DIR = output_dir
def setup_before_launch(
cfg: CfgNode,
output_dir: str,
runner_class: Union[None, str, Type[BaseRunner], Type[DefaultTask]],
) -> Union[None, D2GoSharedContext]:
"""
Setup logic before spawning workers, including:
- Shared context initialization to be passed to all workers
"""
if isinstance(runner_class, str):
logger.info(f"Importing runner: {runner_class} ...")
runner_class = import_runner(runner_class)
if hasattr(runner_class, "create_shared_context"):
return runner_class.create_shared_context(cfg)
return None
def setup_after_launch(
cfg: CfgNode,
output_dir: str,
......
@@ -23,6 +23,7 @@ from d2go.setup import (
post_mortem_if_fail_for_main,
prepare_for_launch,
setup_after_launch,
setup_before_launch,
)
from d2go.utils.misc import print_metrics_table
from mobile_cv.predictor.api import create_predictor
@@ -66,6 +67,7 @@ def main(
def run_with_cmdline_args(args):
cfg, output_dir, runner_name = prepare_for_launch(args)
shared_context = setup_before_launch(cfg, output_dir, runner_name)
main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
launch(
main_func,
@@ -75,6 +77,7 @@ def run_with_cmdline_args(args):
dist_url=args.dist_url,
backend="GLOO",
always_spawn=False,
shared_context=shared_context,
args=(cfg, output_dir, runner_name),
kwargs={
"predictor_path": args.predictor_path,
......
@@ -19,6 +19,7 @@ from d2go.setup import (
post_mortem_if_fail_for_main,
prepare_for_launch,
setup_after_launch,
setup_before_launch,
)
from d2go.trainer.api import TrainNetOutput
from d2go.utils.misc import (
@@ -93,6 +94,7 @@ def main(
def run_with_cmdline_args(args):
cfg, output_dir, runner_name = prepare_for_launch(args)
shared_context = setup_before_launch(cfg, output_dir, runner_name)
main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
outputs = launch(
@@ -102,6 +104,7 @@
machine_rank=args.machine_rank,
dist_url=args.dist_url,
backend=args.dist_backend,
shared_context=shared_context,
args=(cfg, output_dir, runner_name),
kwargs={
"eval_only": args.eval_only,
......