Commit 5c16a4ea authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

use dataclass to annotate the output of main & operator

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/340

Reviewed By: miqueljubert

Differential Revision: D37968017

fbshipit-source-id: a3953fdbb2c48ceaffcf94df081c0b3253d247d5
parent 596f3721
......@@ -6,7 +6,8 @@ Tool for benchmarking data loading
import logging
import time
from typing import Type, Union
from dataclasses import dataclass
from typing import Any, Dict, Type, Union
import detectron2.utils.comm as comm
import numpy as np
......@@ -27,12 +28,19 @@ from fvcore.common.history_buffer import HistoryBuffer
logger = logging.getLogger("d2go.tools.benchmark_data")
@dataclass
class BenchmarkDataOutput:
    """Structured result returned by the benchmark-data tool's `main`.

    Replaces the previous plain-dict return value so callers get attribute
    access and a typed contract instead of string keys.
    """

    # Evaluation results; `main` assigns the same `metrics` object here,
    # presumably kept for backward compatibility with callers reading
    # "accuracy" — TODO confirm against downstream consumers.
    accuracy: Dict[str, Any]
    # TODO: support arbitrary levels of dicts
    # Four-level nesting: name -> dataset_name -> (two more levels inside
    # the evaluator results) -> float. `main` builds the outer two levels as
    # {"_name_": {dataset_name: results}}; inner structure comes from the
    # evaluator and is assumed, not verified here.
    metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]]
def main(
cfg: CfgNode,
output_dir: str,
runner_class: Union[str, Type[BaseRunner]],
is_train: bool = True,
):
) -> BenchmarkDataOutput:
runner = setup_after_launch(cfg, output_dir, runner_class)
if is_train:
......@@ -128,10 +136,10 @@ def main(
metrics = {"_name_": {dataset_name: results}}
print_metrics_table(metrics)
return {
"accuracy": metrics,
"metrics": metrics,
}
return BenchmarkDataOutput(
accuracy=metrics,
metrics=metrics,
)
def run_with_cmdline_args(args):
......
......@@ -8,7 +8,8 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
import logging
import sys
from typing import Optional, Type, Union
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type, Union
import torch
from d2go.config import CfgNode
......@@ -28,6 +29,13 @@ from mobile_cv.predictor.api import create_predictor
logger = logging.getLogger("d2go.tools.caffe2_evaluator")
@dataclass
class EvaluatorOutput:
    """Structured result returned by the evaluator tool's `main`.

    Replaces the previous plain-dict return ({"accuracy": ..., "metrics": ...})
    with a typed dataclass.
    """

    # Evaluation results from `runner.do_test`; `main` assigns the same
    # `metrics` object here — presumably retained so existing callers that
    # read "accuracy" keep working. TODO confirm.
    accuracy: Dict[str, Any]
    # TODO: support arbitrary levels of dicts
    # Nested metrics mapping produced by `runner.do_test(cfg, predictor)`;
    # the exact four-level shape is assumed from the annotation, not
    # established by code visible here.
    metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]]
def main(
cfg: CfgNode,
output_dir: str,
......@@ -37,7 +45,7 @@ def main(
num_threads: Optional[int] = None,
caffe2_engine: Optional[int] = None,
caffe2_logging_print_net_summary: int = 0,
):
) -> EvaluatorOutput:
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
print("run with quantized engine: ", torch.backends.quantized.engine)
......@@ -47,10 +55,10 @@ def main(
predictor = create_predictor(predictor_path)
metrics = runner.do_test(cfg, predictor)
print_metrics_table(metrics)
return {
"accuracy": metrics,
"metrics": metrics,
}
return EvaluatorOutput(
accuracy=metrics,
metrics=metrics,
)
@post_mortem_if_fail()
......
......@@ -9,7 +9,8 @@ deployable format (such as torchscript, caffe2, ...)
import copy
import logging
import sys
from typing import Dict, List, Type, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Type, Union
import mobile_cv.lut.lib.pt.flops_utils as flops_utils
from d2go.config import CfgNode, temp_defrost
......@@ -22,6 +23,12 @@ from mobile_cv.common.misc.py import post_mortem_if_fail
logger = logging.getLogger("d2go.tools.export")
@dataclass
class ExporterOutput:
    """Structured result returned by the export tool's `main`.

    Replaces the previous plain-dict return
    ({"predictor_paths": ..., "accuracy_comparison": {}}).
    """

    # Mapping of predictor type -> path of the exported predictor on disk.
    predictor_paths: Dict[str, str]
    # Accuracy-comparison results; currently always `{}` because `main`
    # raises NotImplementedError when `compare_accuracy` is requested.
    accuracy_comparison: Dict[str, Any]
def main(
cfg: CfgNode,
output_dir: str,
......@@ -31,7 +38,7 @@ def main(
device: str = "cpu",
compare_accuracy: bool = False,
skip_if_fail: bool = False,
):
) -> ExporterOutput:
if compare_accuracy:
raise NotImplementedError(
"compare_accuracy functionality isn't currently supported."
......@@ -76,9 +83,10 @@ def main(
if not skip_if_fail:
raise e
ret = {"predictor_paths": predictor_paths, "accuracy_comparison": {}}
return ret
return ExporterOutput(
predictor_paths=predictor_paths,
accuracy_comparison={},
)
@post_mortem_if_fail()
......
......@@ -4,8 +4,7 @@
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Type, Union
import mobile_cv.torch.utils_pytorch.comm as comm
import pytorch_lightning as pl # type: ignore
......@@ -13,6 +12,7 @@ from d2go.config import CfgNode
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from d2go.tools.train_net import TrainNetOutput
from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
......@@ -28,14 +28,6 @@ logger = logging.getLogger("detectron2go.lightning.train_net")
FINAL_MODEL_CKPT = f"model_final{ModelCheckpoint.FILE_EXTENSION}"
@dataclass
class TrainOutput:
output_dir: str
accuracy: Optional[Dict[str, Any]] = None
tensorboard_log_dir: Optional[str] = None
model_configs: Optional[Dict[str, str]] = None
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
"""Gets the trainer callbacks based on the given D2Go Config.
......@@ -93,7 +85,7 @@ def main(
output_dir: str,
runner_class: Union[str, Type[DefaultTask]],
eval_only: bool = False,
) -> TrainOutput:
) -> TrainNetOutput:
"""Main function for launching a training with lightning trainer
Args:
cfg: D2go config node
......@@ -119,10 +111,10 @@ def main(
else:
model_configs = _do_train(cfg, trainer, task)
return TrainOutput(
output_dir=cfg.OUTPUT_DIR,
return TrainNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
model_configs=model_configs,
)
......
......@@ -7,7 +7,8 @@ Detection Training Script.
import logging
import sys
from typing import List, Type, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, Union
import detectron2.utils.comm as comm
from d2go.config import CfgNode
......@@ -31,13 +32,23 @@ from detectron2.engine.defaults import create_ddp_model
logger = logging.getLogger("d2go.tools.train_net")
@dataclass
class TrainNetOutput:
    """Structured result returned by the train_net tool's `main`.

    Shared by both the detectron2go trainer and the lightning trainer
    (the lightning train_net imports this class and returns it too),
    replacing the previous plain-dict / TrainOutput returns.
    """

    # Same object as `metrics`; kept under the old "accuracy" name for the
    # e2e_workflow (see the "# for e2e_workflow" comment at the return site).
    accuracy: Dict[str, Any]
    # TODO: support arbitrary levels of dicts
    # Nested metrics from `runner.do_test`; the four-level shape is taken
    # from the annotation — inner structure is not established by code
    # visible here.
    metrics: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
    # Paths of dumped trained-model configs ({} in the eval-only path);
    # used by the unit_workflow per the return-site comment.
    model_configs: Dict[str, str]
    # TODO: decide if `tensorboard_log_dir` should be part of output
    # Only populated by the lightning trainer; defaults to None elsewhere.
    tensorboard_log_dir: Optional[str] = None
def main(
cfg: CfgNode,
output_dir: str,
runner_class: Union[str, Type[BaseRunner]],
eval_only: bool = False,
resume: bool = True, # NOTE: always enable resume when running on cluster
):
) -> TrainNetOutput:
runner = setup_after_launch(cfg, output_dir, runner_class)
model = runner.build_model(cfg)
......@@ -55,11 +66,11 @@ def main(
model.eval()
metrics = runner.do_test(cfg, model, train_iter=train_iter)
print_metrics_table(metrics)
return {
"accuracy": metrics,
"model_configs": {},
"metrics": metrics,
}
return TrainNetOutput(
accuracy=metrics,
model_configs={},
metrics=metrics,
)
model = create_ddp_model(
model,
......@@ -75,13 +86,13 @@ def main(
# dump config files for trained models
trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs)
return {
return TrainNetOutput(
# for e2e_workflow
"accuracy": metrics,
accuracy=metrics,
# for unit_workflow
"model_configs": trained_model_configs,
"metrics": metrics,
}
model_configs=trained_model_configs,
metrics=metrics,
)
def run_with_cmdline_args(args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment