Commit 60b6995d authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

Add MAST support for eval

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/583

Extend support to MAST for evaluator binary.

Reviewed By: miqueljubert

Differential Revision: D46762473

fbshipit-source-id: 62ac68f195c89924abf71c9b6a9715d60ffcbf9b
parent 955e53f6
...@@ -131,6 +131,11 @@ def build_basic_cli_args( ...@@ -131,6 +131,11 @@ def build_basic_cli_args(
dist_backend: Optional[str] = None, dist_backend: Optional[str] = None,
disable_post_mortem: bool = False, disable_post_mortem: bool = False,
run_as_worker: bool = False, run_as_worker: bool = False,
# Evaluator args below
predictor_path: Optional[str] = None,
num_threads: Optional[int] = None,
caffe2_engine: Optional[int] = None,
caffe2_logging_print_net_summary: Optional[int] = None,
) -> List[str]: ) -> List[str]:
""" """
Returns parameters in the form of CLI arguments for the binary using Returns parameters in the form of CLI arguments for the binary using
...@@ -161,6 +166,17 @@ def build_basic_cli_args( ...@@ -161,6 +166,17 @@ def build_basic_cli_args(
args += ["--dist-url", str(dist_url)] args += ["--dist-url", str(dist_url)]
if dist_backend is not None: if dist_backend is not None:
args += ["--dist-backend", str(dist_backend)] args += ["--dist-backend", str(dist_backend)]
if predictor_path is not None:
args += ["--predictor-path", predictor_path]
if num_threads is not None:
args += ["--num-threads", int(num_threads)]
if caffe2_engine is not None:
args += ["--caffe2-engine", int(caffe2_engine)]
if caffe2_logging_print_net_summary is not None:
args += [
"--caffe2_logging_print_net_summary",
str(caffe2_logging_print_net_summary),
]
return args return args
......
...@@ -29,6 +29,12 @@ class TestNetOutput: ...@@ -29,6 +29,12 @@ class TestNetOutput:
tensorboard_log_dir: Optional[str] = None tensorboard_log_dir: Optional[str] = None
@dataclass
class EvaluatorOutput:
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
def do_train(): def do_train():
pass pass
......
...@@ -9,7 +9,7 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e ...@@ -9,7 +9,7 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
import logging import logging
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Type, Union from typing import List, Optional, Type, Union
import torch import torch
from d2go.config import CfgNode from d2go.config import CfgNode
...@@ -19,6 +19,7 @@ from d2go.quantization.qconfig import smart_decode_backend ...@@ -19,6 +19,7 @@ from d2go.quantization.qconfig import smart_decode_backend
from d2go.runner import BaseRunner from d2go.runner import BaseRunner
from d2go.setup import ( from d2go.setup import (
basic_argument_parser, basic_argument_parser,
build_basic_cli_args,
caffe2_global_init, caffe2_global_init,
post_mortem_if_fail_for_main, post_mortem_if_fail_for_main,
prepare_for_launch, prepare_for_launch,
...@@ -26,19 +27,14 @@ from d2go.setup import ( ...@@ -26,19 +27,14 @@ from d2go.setup import (
setup_before_launch, setup_before_launch,
setup_root_logger, setup_root_logger,
) )
from d2go.trainer.api import EvaluatorOutput
from d2go.utils.misc import print_metrics_table from d2go.utils.misc import print_metrics_table, save_binary_outputs
from mobile_cv.predictor.api import create_predictor from mobile_cv.predictor.api import create_predictor
logger = logging.getLogger("d2go.tools.caffe2_evaluator") logger = logging.getLogger("d2go.tools.caffe2_evaluator")
@dataclass
class EvaluatorOutput:
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
def main( def main(
cfg: CfgNode, cfg: CfgNode,
output_dir: str, output_dir: str,
...@@ -71,7 +67,7 @@ def run_with_cmdline_args(args): ...@@ -71,7 +67,7 @@ def run_with_cmdline_args(args):
cfg, output_dir, runner_name = prepare_for_launch(args) cfg, output_dir, runner_name = prepare_for_launch(args)
shared_context = setup_before_launch(cfg, output_dir, runner_name) shared_context = setup_before_launch(cfg, output_dir, runner_name)
main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main) main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
launch( outputs = launch(
main_func, main_func,
args.num_processes, args.num_processes,
num_machines=args.num_machines, num_machines=args.num_machines,
...@@ -88,6 +84,25 @@ def run_with_cmdline_args(args): ...@@ -88,6 +84,25 @@ def run_with_cmdline_args(args):
"caffe2_logging_print_net_summary": args.caffe2_logging_print_net_summary, "caffe2_logging_print_net_summary": args.caffe2_logging_print_net_summary,
}, },
) )
# Only save results from global rank 0 for consistency.
if args.save_return_file is not None and args.machine_rank == 0:
save_binary_outputs(args.save_return_file, outputs[0])
def build_cli_args(
eval_only: bool = False,
resume: bool = False,
**kwargs,
) -> List[str]:
"""Returns parameters in the form of CLI arguments for evaluator binary.
For the list of non-train_net-specific parameters, see build_basic_cli_args."""
args = build_basic_cli_args(**kwargs)
if eval_only:
args += ["--eval-only"]
if resume:
args += ["--resume"]
return args
def cli(args=None): def cli(args=None):
......
...@@ -52,7 +52,7 @@ def main( ...@@ -52,7 +52,7 @@ def main(
) -> Union[TrainNetOutput, TestNetOutput]: ) -> Union[TrainNetOutput, TestNetOutput]:
logger.info("Starting main") logger.info("Starting main")
error_handler = get_error_handler() error_handler = get_error_handler()
logger.debug(f">>>>>>> Error handler is: {type(error_handler)=}, {error_handler=}") logger.debug(f"Error handler is: {type(error_handler)=}, {error_handler=}")
error_handler.initialize() error_handler.initialize()
logger.debug("Error handler has been initialized") logger.debug("Error handler has been initialized")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment