Commit 7d35bae7 authored by Jessica Zhong, committed by Facebook GitHub Bot

Enable Torch Elastic Launch on Mast in D2go

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/566

Reviewed By: wat3rBro

Differential Revision: D45829249

fbshipit-source-id: 4e70bed0e85179b49b4e2358be3d937cfbf474d4
parent 87956d50
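The change, in outline: a new `--run-as-worker` flag lets an elastic launcher (torchelastic on MAST) start each rank as an already-provisioned process that calls `distributed_worker` directly, instead of having the binary spawn its own workers via `launch`. Below is a minimal runnable sketch of that dispatch; the function bodies are hypothetical stand-ins rather than d2go's real implementations, and the `RANK`/`WORLD_SIZE` variable names are the standard torch.distributed ones, assumed here rather than taken from this diff.

```python
# Hypothetical stand-ins illustrating the two launch modes this commit wires up.
import os


def distributed_worker(main_func):
    # In worker mode the elastic agent has already created this process;
    # d2go derives the topology via DistributedParams.from_environ().
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return main_func(rank, world_size)


def launch(main_func, num_processes):
    # Stand-in for the self-spawning path: run every rank inline for the demo.
    return [main_func(rank, num_processes) for rank in range(num_processes)]


def run(run_as_worker, main_func, num_processes):
    if run_as_worker:
        return distributed_worker(main_func)
    return launch(main_func, num_processes)[0]


print(run(False, lambda r, w: f"rank {r}/{w}", num_processes=2))
```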
@@ -73,6 +73,19 @@ def distributed_worker(
         shared_context
     )  # set the global shared context from the args passed in by mp spawn
     dist_params = dist_params or DistributedParams.from_environ()
+    if get_launch_environment() == "local" and not torch.cuda.is_available():
+        assert len(args) > 0, args
+        cfg = args[0]
+        if isinstance(cfg, CfgNode) and cfg.MODEL.DEVICE == "cuda":
+            logger.warning(
+                "Detected that CUDA is not available on this machine, set MODEL.DEVICE"
+                " to cpu and backend to GLOO"
+            )
+            with temp_defrost(cfg):
+                cfg.MODEL.DEVICE = "cpu"
+            backend = "GLOO"
     with enable_dist_process_groups(backend, init_method, dist_params, timeout):
         d2_comm._LOCAL_PROCESS_GROUP = mcv_comm._LOCAL_PROCESS_GROUP
         # Now the D2's comm module should be fully functional
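This block (moved here from `launch`, see the next hunk) downgrades a CUDA config to CPU with the GLOO backend when no GPU is present, so the fallback now also fires in worker mode. A standalone sketch of the same fallback, assuming a plain yacs-style frozen config (d2go's `CfgNode` ultimately derives from yacs; the real code uses the `temp_defrost` helper instead of explicit defrost/freeze):

```python
# Minimal sketch of the CPU/GLOO fallback, using plain yacs instead of d2go.
import torch
from yacs.config import CfgNode

cfg = CfgNode({"MODEL": CfgNode({"DEVICE": "cuda"})})
cfg.freeze()

backend = "NCCL"
if not torch.cuda.is_available() and cfg.MODEL.DEVICE == "cuda":
    cfg.defrost()  # the real code wraps this in temp_defrost(cfg)
    cfg.MODEL.DEVICE = "cpu"
    cfg.freeze()
    backend = "GLOO"

print(cfg.MODEL.DEVICE, backend)  # "cpu GLOO" on a machine without CUDA
```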
@@ -106,19 +119,6 @@ def launch(
     - Automatically convert GPU to CPU if CUDA is not available.
     - Add D2Go-specific initialization in the _distributed_worker.
     """
-    if get_launch_environment() == "local" and not torch.cuda.is_available():
-        assert len(args) > 0, args
-        cfg = args[0]
-        if isinstance(cfg, CfgNode) and cfg.MODEL.DEVICE == "cuda":
-            logger.warning(
-                "Detected that CUDA is not available on this machine, set MODEL.DEVICE"
-                " to cpu and backend to GLOO"
-            )
-            with temp_defrost(cfg):
-                cfg.MODEL.DEVICE = "cpu"
-            backend = "GLOO"
     return _launch(
         main_func=main_func,
         num_processes_per_machine=num_processes_per_machine,
......
@@ -104,6 +104,7 @@ def basic_argument_parser(
         "--num-processes", type=int, default=1, help="number of gpus per machine"
     )
     parser.add_argument("--num-machines", type=int, default=1)
+    parser.add_argument("--run-as-worker", type=bool, default=False)
    parser.add_argument(
         "--machine-rank",
         type=int,
@@ -129,6 +130,7 @@ def build_basic_cli_args(
     dist_url: Optional[str] = None,
     dist_backend: Optional[str] = None,
     disable_post_mortem: bool = False,
+    run_as_worker: bool = False,
 ) -> List[str]:
     """
     Returns parameters in the form of CLI arguments for the binary using
@@ -147,6 +149,8 @@ def build_basic_cli_args(
         args += ["--save-return-file", str(save_return_file)]
     if disable_post_mortem:
         args += ["--disable-post-mortem"]
+    if run_as_worker:
+        args += ["--run-as-worker", str(run_as_worker)]
     if num_processes is not None:
         args += ["--num-processes", str(num_processes)]
     if num_machines is not None:
......
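For round-tripping, `build_basic_cli_args` gains the matching keyword. Note the flag is only emitted when `run_as_worker` is truthy, which sidesteps argparse's `type=bool` quirk (`bool("False")` is truthy, so an explicit `--run-as-worker False` would still parse as True). A hypothetical call, assuming d2go is installed and the keyword signature shown in the diff above:

```python
# Hypothetical usage of the extended helper from the diff above.
from d2go.setup import build_basic_cli_args

argv = build_basic_cli_args(
    num_processes=8,
    num_machines=2,
    run_as_worker=True,  # new in this diff; emits ["--run-as-worker", "True"]
)
print(argv)
```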
@@ -11,7 +11,7 @@ from typing import List, Type, Union
 import detectron2.utils.comm as comm
 from d2go.config import CfgNode
-from d2go.distributed import launch
+from d2go.distributed import distributed_worker, launch
 from d2go.runner import BaseRunner
 from d2go.setup import (
     basic_argument_parser,
@@ -141,24 +141,42 @@ def run_with_cmdline_args(args):
     shared_context = setup_before_launch(cfg, output_dir, runner_name)
     main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
-    outputs = launch(
-        main_func,
-        num_processes_per_machine=args.num_processes,
-        num_machines=args.num_machines,
-        machine_rank=args.machine_rank,
-        dist_url=args.dist_url,
-        backend=args.dist_backend,
-        shared_context=shared_context,
-        args=(cfg, output_dir, runner_name),
-        kwargs={
-            "eval_only": args.eval_only,
-            "resume": args.resume,
-        },
-    )
+    if args.run_as_worker:
+        logger.info("Running as worker")
+        result = distributed_worker(
+            main_func,
+            args=(cfg, output_dir, runner_name),
+            kwargs={
+                "eval_only": args.eval_only,
+                "resume": args.resume,
+            },
+            backend=args.dist_backend,
+            init_method=None,  # init_method is env by default
+            dist_params=None,
+            return_save_file=None,
+            shared_context=shared_context,
+        )
+    else:
+        outputs = launch(
+            main_func,
+            num_processes_per_machine=args.num_processes,
+            num_machines=args.num_machines,
+            machine_rank=args.machine_rank,
+            dist_url=args.dist_url,
+            backend=args.dist_backend,
+            shared_context=shared_context,
+            args=(cfg, output_dir, runner_name),
+            kwargs={
+                "eval_only": args.eval_only,
+                "resume": args.resume,
+            },
+        )
+        result = outputs[0]
     # Only save results from global rank 0 for consistency.
     if args.save_return_file is not None and args.machine_rank == 0:
-        save_binary_outputs(args.save_return_file, outputs[0])
+        save_binary_outputs(args.save_return_file, result)
 def cli(args=None):
......
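In worker mode the call passes `init_method=None` and `dist_params=None`: `distributed_worker` then falls back to `DistributedParams.from_environ()` (see the first hunk), and torch.distributed's default rendezvous is `env://`, reading the variables the elastic agent exports. A self-contained single-process demonstration with the gloo backend; the values below are placeholders, not anything mandated by this commit:

```python
# Single-process demo of env:// rendezvous, the default when init_method=None.
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group(backend="gloo")  # equivalent to init_method="env://"
print("rank", dist.get_rank(), "of", dist.get_world_size())
dist.destroy_process_group()
```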