"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "729178c7cb3984f1ac04dd89b1610e95abf5cb4a"
Commit 95e429a1 authored by Yanghan Wang, committed by Facebook GitHub Bot

minor update of result gathering logic

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/596

`outputs = {0: result}` is a bit hacky; technically it should be `outputs = {worker_rank: result}` so that it matches the `outputs` semantics of the else-branch.

Reviewed By: frabu6

Differential Revision: D47442322

fbshipit-source-id: f4d24f7022971b4f919b4fb4a563164c7f71cd2b
parent bbfdc182
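For context on the gathering semantics this diff aligns on, a minimal, hypothetical sketch (the helper name `local_master_result` and the example values are illustrative; only `outputs`, `machine_rank`, and `num_processes` mirror names used in the diff):

from typing import Dict, TypeVar

_RT = TypeVar("_RT")


def local_master_result(outputs: Dict[int, _RT], machine_rank: int, num_processes: int) -> _RT:
    # `outputs` maps each worker's global rank to that worker's return value.
    # The local master on a machine has global rank machine_rank * num_processes,
    # which matches the `outputs` semantics of the launch() else-branch below.
    return outputs[machine_rank * num_processes]


# Example: with 8 processes per machine, machine 1's local master is global rank 8.
# result = local_master_result(outputs, machine_rank=1, num_processes=8)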
@@ -6,7 +6,7 @@ import argparse
 import logging
 import os
 import time
-from typing import List, Optional, Tuple, Type, Union
+from typing import Callable, List, Optional, Tuple, Type, TypeVar, Union
 
 import detectron2.utils.comm as comm
 import torch
@@ -39,6 +39,8 @@ from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem
 logger = logging.getLogger(__name__)
 
+_RT = TypeVar("_RT")
+
 
 @run_once()
 def setup_root_logger(logging_level: int = logging.INFO) -> None:
@@ -214,7 +216,7 @@ def create_cfg_from_cli(
 
 def prepare_for_launch(
     args,
-) -> Tuple[CfgNode, str, Optional[str]]:
+) -> Tuple[CfgNode, str, str]:
     """
     Load config, figure out working directory, create runner.
         - when args.config_file is empty, returned cfg will be the default one
@@ -436,8 +438,8 @@ def caffe2_global_init(logging_print_net_summary=0, num_threads=None):
     logger.info("Using {} threads after GlobalInit".format(torch.get_num_threads()))
 
 
-def post_mortem_if_fail_for_main(main_func):
-    def new_main_func(cfg, output_dir, *args, **kwargs):
+def post_mortem_if_fail_for_main(main_func: Callable[..., _RT]) -> Callable[..., _RT]:
+    def new_main_func(cfg, output_dir, *args, **kwargs) -> _RT:
         pdb_ = (
             MultiprocessingPdb(FolderLock(output_dir))
             if comm.get_world_size() > 1
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
 import logging
-from typing import Callable
+from typing import Callable, TypeVar
 
 from torch.distributed.elastic.multiprocessing.errors import (
     _NOT_AVAILABLE,
@@ -11,9 +11,11 @@ from torch.distributed.elastic.multiprocessing.errors import (
 
 logger = logging.getLogger(__name__)
 
+_RT = TypeVar("_RT")
+
 
-def mast_error_handler(func: Callable) -> Callable:
-    def wrapper(*args, **kwargs):
+def mast_error_handler(func: Callable[..., _RT]) -> Callable[..., _RT]:
+    def wrapper(*args, **kwargs) -> _RT:
         logger.info("Starting main")
         error_handler = get_error_handler()
         logger.debug(f"Error handler is: {type(error_handler)=}, {error_handler=}")
@@ -44,11 +46,11 @@ def mast_error_handler(func: Callable) -> Callable:
     return wrapper
 
 
-def gather_mast_errors(func: Callable) -> Callable:
-    def wrapper(*args, **kwargs) -> None:
+def gather_mast_errors(func: Callable[..., _RT]) -> Callable[..., _RT]:
+    def wrapper(*args, **kwargs) -> _RT:
         logger.info("Starting CLI application")
         try:
-            func(*args, **kwargs)
+            return func(*args, **kwargs)
         finally:
             logging.info("Entering final reply file generation step")
             import glob
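The annotation changes above follow the standard TypeVar-preserving decorator pattern. A minimal standalone sketch, independent of d2go (the names `log_calls` and `add` are illustrative only):

from typing import Callable, TypeVar

_RT = TypeVar("_RT")


def log_calls(func: Callable[..., _RT]) -> Callable[..., _RT]:
    # Parameterizing Callable with _RT lets type checkers see that the wrapper
    # returns whatever `func` returns, instead of a bare, untyped Callable.
    def wrapper(*args, **kwargs) -> _RT:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@log_calls
def add(a: int, b: int) -> int:
    return a + b


result: int = add(1, 2)  # the checker now infers int rather than Any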
@@ -7,7 +7,7 @@ Detection Training Script.
 
 import logging
 import sys
-from typing import Callable, List, Type, Union
+from typing import Callable, Dict, List, Type, Union
 
 import detectron2.utils.comm as comm
 from d2go.config import CfgNode
@@ -38,13 +38,16 @@ logger = logging.getLogger("d2go.tools.train_net")
 setup_root_logger()
 
+TrainOrTestNetOutput = Union[TrainNetOutput, TestNetOutput]
+
 
 def main(
     cfg: CfgNode,
     output_dir: str,
     runner_class: Union[str, Type[BaseRunner]],
     eval_only: bool = False,
     resume: bool = True,  # NOTE: always enable resume when running on cluster
-) -> Union[TrainNetOutput, TestNetOutput]:
+) -> TrainOrTestNetOutput:
     logger.debug(f"Entered main for d2go, {runner_class=}")
     runner = setup_after_launch(cfg, output_dir, runner_class)
@@ -102,7 +105,7 @@ def main(
     )
 
 
-def wrapped_main(*args, **kwargs) -> Callable:
+def wrapped_main(*args, **kwargs) -> Callable[..., TrainOrTestNetOutput]:
     return mast_error_handler(main)(*args, **kwargs)
@@ -118,7 +121,7 @@ def run_with_cmdline_args(args):
 
     if args.run_as_worker:
         logger.info("Running as worker")
-        result = distributed_worker(
+        result: TrainOrTestNetOutput = distributed_worker(
             main_func,
             args=(cfg, output_dir, runner_name),
             kwargs={
@@ -131,9 +134,8 @@ def run_with_cmdline_args(args):
             return_save_file=None,
             shared_context=shared_context,
         )
-        outputs = {0: result}
     else:
-        outputs = launch(
+        outputs: Dict[int, TrainOrTestNetOutput] = launch(
             main_func,
             num_processes_per_machine=args.num_processes,
             num_machines=args.num_machines,
@@ -147,12 +149,15 @@
                 "resume": args.resume,
             },
         )
+        # The indices of outputs are global ranks of all workers on this node, here we
+        # use the local master result.
+        result: TrainOrTestNetOutput = outputs[args.machine_rank * args.num_processes]
 
-    # Only save results from global rank 0 for consistency.
+    # Only save result from global rank 0 for consistency.
     if args.save_return_file is not None and args.machine_rank == 0:
-        logger.info(f"Operator results: {outputs[0]}")
-        logger.info(f"Writing results to {args.save_return_file}.")
-        save_binary_outputs(args.save_return_file, outputs[0])
+        logger.info(f"Operator result: {result}")
+        logger.info(f"Writing result to {args.save_return_file}.")
+        save_binary_outputs(args.save_return_file, result)
 
 
 def cli(args=None):