Commit b57fde40 authored by Tsahi Glik's avatar Tsahi Glik Committed by Facebook GitHub Bot
Browse files

Support saving results in d2go tools

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/297

X-link: https://github.com/facebookresearch/mobile-vision/pull/84

Add command line arg to specify whether and where to save results.
This is useful where binaries are being launched from another process, or remotely on another machine.

Reviewed By: wat3rBro

Differential Revision: D37157955

fbshipit-source-id: 2a48cf967f6cf928049f2be41952834e1dd2a04d
parent da04300b
...@@ -62,6 +62,12 @@ def basic_argument_parser( ...@@ -62,6 +62,12 @@ def basic_argument_parser(
default=None, default=None,
nargs=argparse.REMAINDER, nargs=argparse.REMAINDER,
) )
parser.add_argument(
"--save-return-file",
help="When given, the main function outputs will be serialized and saved to this file",
default=None,
type=str,
)
if distributed: if distributed:
parser.add_argument( parser.add_argument(
...@@ -86,6 +92,7 @@ def build_basic_cli_args( ...@@ -86,6 +92,7 @@ def build_basic_cli_args(
config_path: Optional[str] = None, config_path: Optional[str] = None,
output_dir: Optional[str] = None, output_dir: Optional[str] = None,
runner_name: Optional[str] = None, runner_name: Optional[str] = None,
save_return_file: Optional[str] = None,
num_processes: Optional[Union[int, str]] = None, num_processes: Optional[Union[int, str]] = None,
num_machines: Optional[Union[int, str]] = None, num_machines: Optional[Union[int, str]] = None,
machine_rank: Optional[Union[int, str]] = None, machine_rank: Optional[Union[int, str]] = None,
...@@ -105,6 +112,8 @@ def build_basic_cli_args( ...@@ -105,6 +112,8 @@ def build_basic_cli_args(
args += ["--output-dir", output_dir] args += ["--output-dir", output_dir]
if runner_name is not None: if runner_name is not None:
args += ["--runner", runner_name] args += ["--runner", runner_name]
if save_return_file is not None:
args += ["--save-return-file", str(save_return_file)]
if num_processes is not None: if num_processes is not None:
args += ["--num-processes", str(num_processes)] args += ["--num-processes", str(num_processes)]
if num_machines is not None: if num_machines is not None:
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
import os import os
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Iterator from typing import Any, Dict, Iterator
# @manual=//vision/fair/detectron2/detectron2:detectron2 # @manual=//vision/fair/detectron2/detectron2:detectron2
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
...@@ -118,6 +118,18 @@ def read_trained_model_configs(output_dir: str) -> Dict[str, str]: ...@@ -118,6 +118,18 @@ def read_trained_model_configs(output_dir: str) -> Dict[str, str]:
} }
def save_binary_outputs(filename: str, outputs: Any) -> None:
"""Helper function to serialize and save function outputs in binary format."""
with PathManager.open(filename, "wb") as f:
torch.save(outputs, f)
def load_binary_outputs(filename: str) -> Any:
"""Helper function to load and deserialize function outputs saved in binary format."""
with PathManager.open(filename, "rb") as f:
return torch.load(f)
@contextmanager @contextmanager
def mode(net: torch.nn.Module, training: bool) -> Iterator[torch.nn.Module]: def mode(net: torch.nn.Module, training: bool) -> Iterator[torch.nn.Module]:
"""Temporarily switch to training/evaluation mode.""" """Temporarily switch to training/evaluation mode."""
......
...@@ -18,7 +18,11 @@ from d2go.setup import ( ...@@ -18,7 +18,11 @@ from d2go.setup import (
prepare_for_launch, prepare_for_launch,
setup_after_launch, setup_after_launch,
) )
from d2go.utils.misc import dump_trained_model_configs, print_metrics_table from d2go.utils.misc import (
dump_trained_model_configs,
print_metrics_table,
save_binary_outputs,
)
from detectron2.engine.defaults import create_ddp_model from detectron2.engine.defaults import create_ddp_model
...@@ -81,7 +85,8 @@ def main( ...@@ -81,7 +85,8 @@ def main(
def run_with_cmdline_args(args): def run_with_cmdline_args(args):
cfg, output_dir, runner = prepare_for_launch(args) cfg, output_dir, runner = prepare_for_launch(args)
launch(
outputs = launch(
post_mortem_if_fail_for_main(main), post_mortem_if_fail_for_main(main),
num_processes_per_machine=args.num_processes, num_processes_per_machine=args.num_processes,
num_machines=args.num_machines, num_machines=args.num_machines,
...@@ -91,6 +96,11 @@ def run_with_cmdline_args(args): ...@@ -91,6 +96,11 @@ def run_with_cmdline_args(args):
args=(cfg, output_dir, runner, args.eval_only, args.resume), args=(cfg, output_dir, runner, args.eval_only, args.resume),
) )
if args.save_return_file is not None:
save_binary_outputs(args.save_return_file, outputs)
return outputs
def cli(args=None): def cli(args=None):
parser = basic_argument_parser(requires_output_dir=False) parser = basic_argument_parser(requires_output_dir=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment