"vscode:/vscode.git/clone" did not exist on "d50e3217459558cc2979f38818f1835751d4fc97"
Commit 95e429a1 authored by Yanghan Wang, committed by Facebook GitHub Bot

minor update of result gathering logic

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/596

`outputs = {0: result}` feels a bit hacky; technically it should be `outputs = {worker_rank: result}` to match the semantics of `outputs` in the else-branch.
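To make that semantic concrete, here is a minimal standalone sketch (the helper name `run_local_workers` is hypothetical and not part of the d2go API) of an `outputs` dict keyed by global worker rank, from which the local master result is picked with `machine_rank * num_processes`:

```python
# Illustrative sketch only: `run_local_workers` is a hypothetical stand-in for
# d2go's `launch`, which returns results keyed by each worker's *global* rank.
from typing import Callable, Dict, TypeVar

_RT = TypeVar("_RT")


def run_local_workers(
    main_func: Callable[[int], _RT],
    num_processes: int,
    machine_rank: int,
) -> Dict[int, _RT]:
    # Run one "worker" per local slot (sequentially here, just to show the
    # keying) and index each result by its global rank.
    outputs: Dict[int, _RT] = {}
    for local_rank in range(num_processes):
        global_rank = machine_rank * num_processes + local_rank
        outputs[global_rank] = main_func(global_rank)
    return outputs


outputs = run_local_workers(lambda rank: f"result-{rank}", num_processes=2, machine_rank=1)
# The local master on machine 1 is global rank 1 * 2 = 2, matching
# `outputs[args.machine_rank * args.num_processes]` in the patch below.
result = outputs[1 * 2]
```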

Reviewed By: frabu6

Differential Revision: D47442322

fbshipit-source-id: f4d24f7022971b4f919b4fb4a563164c7f71cd2b
parent bbfdc182
@@ -6,7 +6,7 @@ import argparse
 import logging
 import os
 import time
-from typing import List, Optional, Tuple, Type, Union
+from typing import Callable, List, Optional, Tuple, Type, TypeVar, Union
 import detectron2.utils.comm as comm
 import torch
@@ -39,6 +39,8 @@ from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem
 logger = logging.getLogger(__name__)
+_RT = TypeVar("_RT")
 @run_once()
 def setup_root_logger(logging_level: int = logging.INFO) -> None:
@@ -214,7 +216,7 @@ def create_cfg_from_cli(
 def prepare_for_launch(
     args,
-) -> Tuple[CfgNode, str, Optional[str]]:
+) -> Tuple[CfgNode, str, str]:
     """
     Load config, figure out working directory, create runner.
     - when args.config_file is empty, returned cfg will be the default one
@@ -436,8 +438,8 @@ def caffe2_global_init(logging_print_net_summary=0, num_threads=None):
     logger.info("Using {} threads after GlobalInit".format(torch.get_num_threads()))
-def post_mortem_if_fail_for_main(main_func):
-    def new_main_func(cfg, output_dir, *args, **kwargs):
+def post_mortem_if_fail_for_main(main_func: Callable[..., _RT]) -> Callable[..., _RT]:
+    def new_main_func(cfg, output_dir, *args, **kwargs) -> _RT:
         pdb_ = (
             MultiprocessingPdb(FolderLock(output_dir))
             if comm.get_world_size() > 1
......
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import logging
-from typing import Callable
+from typing import Callable, TypeVar
 from torch.distributed.elastic.multiprocessing.errors import (
     _NOT_AVAILABLE,
@@ -11,9 +11,11 @@ from torch.distributed.elastic.multiprocessing.errors import (
 logger = logging.getLogger(__name__)
+_RT = TypeVar("_RT")
-def mast_error_handler(func: Callable) -> Callable:
-    def wrapper(*args, **kwargs):
+def mast_error_handler(func: Callable[..., _RT]) -> Callable[..., _RT]:
+    def wrapper(*args, **kwargs) -> _RT:
         logger.info("Starting main")
         error_handler = get_error_handler()
         logger.debug(f"Error handler is: {type(error_handler)=}, {error_handler=}")
@@ -44,11 +46,11 @@ def mast_error_handler(func: Callable) -> Callable:
     return wrapper
-def gather_mast_errors(func: Callable) -> Callable:
-    def wrapper(*args, **kwargs) -> None:
+def gather_mast_errors(func: Callable[..., _RT]) -> Callable[..., _RT]:
+    def wrapper(*args, **kwargs) -> _RT:
         logger.info("Starting CLI application")
         try:
-            func(*args, **kwargs)
+            return func(*args, **kwargs)
         finally:
             logging.info("Entering final reply file generation step")
             import glob
......
@@ -7,7 +7,7 @@ Detection Training Script.
 import logging
 import sys
-from typing import Callable, List, Type, Union
+from typing import Callable, Dict, List, Type, Union
 import detectron2.utils.comm as comm
 from d2go.config import CfgNode
@@ -38,13 +38,16 @@ logger = logging.getLogger("d2go.tools.train_net")
 setup_root_logger()
+TrainOrTestNetOutput = Union[TrainNetOutput, TestNetOutput]
 def main(
     cfg: CfgNode,
     output_dir: str,
     runner_class: Union[str, Type[BaseRunner]],
     eval_only: bool = False,
     resume: bool = True,  # NOTE: always enable resume when running on cluster
-) -> Union[TrainNetOutput, TestNetOutput]:
+) -> TrainOrTestNetOutput:
     logger.debug(f"Entered main for d2go, {runner_class=}")
     runner = setup_after_launch(cfg, output_dir, runner_class)
@@ -102,7 +105,7 @@ def main(
     )
-def wrapped_main(*args, **kwargs) -> Callable:
+def wrapped_main(*args, **kwargs) -> Callable[..., TrainOrTestNetOutput]:
     return mast_error_handler(main)(*args, **kwargs)
@@ -118,7 +121,7 @@ def run_with_cmdline_args(args):
     if args.run_as_worker:
         logger.info("Running as worker")
-        result = distributed_worker(
+        result: TrainOrTestNetOutput = distributed_worker(
             main_func,
             args=(cfg, output_dir, runner_name),
             kwargs={
@@ -131,9 +134,8 @@ def run_with_cmdline_args(args):
             return_save_file=None,
             shared_context=shared_context,
         )
-        outputs = {0: result}
     else:
-        outputs = launch(
+        outputs: Dict[int, TrainOrTestNetOutput] = launch(
             main_func,
             num_processes_per_machine=args.num_processes,
             num_machines=args.num_machines,
@@ -147,12 +149,15 @@ def run_with_cmdline_args(args):
                 "resume": args.resume,
             },
         )
+        # The indices of outputs are global ranks of all workers on this node, here we
+        # use the local master result.
+        result: TrainOrTestNetOutput = outputs[args.machine_rank * args.num_processes]
-    # Only save results from global rank 0 for consistency.
+    # Only save result from global rank 0 for consistency.
     if args.save_return_file is not None and args.machine_rank == 0:
-        logger.info(f"Operator results: {outputs[0]}")
-        logger.info(f"Writing results to {args.save_return_file}.")
-        save_binary_outputs(args.save_return_file, outputs[0])
+        logger.info(f"Operator result: {result}")
+        logger.info(f"Writing result to {args.save_return_file}.")
+        save_binary_outputs(args.save_return_file, result)
 def cli(args=None):
......
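Aside from the result-gathering change, the diff also parameterizes the decorators with a `TypeVar` so type checkers see the wrapped function's real return type. A minimal standalone sketch of that pattern (illustrative names only, not code from the repository):

```python
# Standalone sketch of the `Callable[..., _RT]` decorator typing used in the
# diff above; the decorator and function names here are illustrative only.
import logging
from typing import Callable, TypeVar

_RT = TypeVar("_RT")
logger = logging.getLogger(__name__)


def log_errors(func: Callable[..., _RT]) -> Callable[..., _RT]:
    def wrapper(*args, **kwargs) -> _RT:
        try:
            # Returning (not just calling) the wrapped function preserves its
            # result, mirroring the `return func(*args, **kwargs)` fix above.
            return func(*args, **kwargs)
        except Exception:
            logger.exception("wrapped call failed")
            raise

    return wrapper


@log_errors
def main() -> int:
    return 0


result: int = main()  # a type checker now infers `int` for the call result
```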