Commit 60b6995d authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

Add MAST support for eval

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/583

Extend support to MAST for evaluator binary.

Reviewed By: miqueljubert

Differential Revision: D46762473

fbshipit-source-id: 62ac68f195c89924abf71c9b6a9715d60ffcbf9b
parent 955e53f6
......@@ -131,6 +131,11 @@ def build_basic_cli_args(
dist_backend: Optional[str] = None,
disable_post_mortem: bool = False,
run_as_worker: bool = False,
# Evaluator args below
predictor_path: Optional[str] = None,
num_threads: Optional[int] = None,
caffe2_engine: Optional[int] = None,
caffe2_logging_print_net_summary: Optional[int] = None,
) -> List[str]:
"""
Returns parameters in the form of CLI arguments for the binary using
......@@ -161,6 +166,17 @@ def build_basic_cli_args(
args += ["--dist-url", str(dist_url)]
if dist_backend is not None:
args += ["--dist-backend", str(dist_backend)]
if predictor_path is not None:
args += ["--predictor-path", predictor_path]
if num_threads is not None:
args += ["--num-threads", int(num_threads)]
if caffe2_engine is not None:
args += ["--caffe2-engine", int(caffe2_engine)]
if caffe2_logging_print_net_summary is not None:
args += [
"--caffe2_logging_print_net_summary",
str(caffe2_logging_print_net_summary),
]
return args
......
......@@ -29,6 +29,12 @@ class TestNetOutput:
tensorboard_log_dir: Optional[str] = None
@dataclass
class EvaluatorOutput:
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
def do_train():
pass
......
......@@ -9,7 +9,7 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
import logging
import sys
from dataclasses import dataclass
from typing import Optional, Type, Union
from typing import List, Optional, Type, Union
import torch
from d2go.config import CfgNode
......@@ -19,6 +19,7 @@ from d2go.quantization.qconfig import smart_decode_backend
from d2go.runner import BaseRunner
from d2go.setup import (
basic_argument_parser,
build_basic_cli_args,
caffe2_global_init,
post_mortem_if_fail_for_main,
prepare_for_launch,
......@@ -26,19 +27,14 @@ from d2go.setup import (
setup_before_launch,
setup_root_logger,
)
from d2go.trainer.api import EvaluatorOutput
from d2go.utils.misc import print_metrics_table
from d2go.utils.misc import print_metrics_table, save_binary_outputs
from mobile_cv.predictor.api import create_predictor
logger = logging.getLogger("d2go.tools.caffe2_evaluator")
@dataclass
class EvaluatorOutput:
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
def main(
cfg: CfgNode,
output_dir: str,
......@@ -71,7 +67,7 @@ def run_with_cmdline_args(args):
cfg, output_dir, runner_name = prepare_for_launch(args)
shared_context = setup_before_launch(cfg, output_dir, runner_name)
main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
launch(
outputs = launch(
main_func,
args.num_processes,
num_machines=args.num_machines,
......@@ -88,6 +84,25 @@ def run_with_cmdline_args(args):
"caffe2_logging_print_net_summary": args.caffe2_logging_print_net_summary,
},
)
# Only save results from global rank 0 for consistency.
if args.save_return_file is not None and args.machine_rank == 0:
save_binary_outputs(args.save_return_file, outputs[0])
def build_cli_args(
eval_only: bool = False,
resume: bool = False,
**kwargs,
) -> List[str]:
"""Returns parameters in the form of CLI arguments for evaluator binary.
For the list of non-train_net-specific parameters, see build_basic_cli_args."""
args = build_basic_cli_args(**kwargs)
if eval_only:
args += ["--eval-only"]
if resume:
args += ["--resume"]
return args
def cli(args=None):
......
......@@ -52,7 +52,7 @@ def main(
) -> Union[TrainNetOutput, TestNetOutput]:
logger.info("Starting main")
error_handler = get_error_handler()
logger.debug(f">>>>>>> Error handler is: {type(error_handler)=}, {error_handler=}")
logger.debug(f"Error handler is: {type(error_handler)=}, {error_handler=}")
error_handler.initialize()
logger.debug("Error handler has been initialized")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment