"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "1b2d0899b4ed09e48bcef6aebfc3b815900a86e5"
Commit 5c16a4ea authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

use dataclass to annotate the output of main & operator

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/340

Reviewed By: miqueljubert

Differential Revision: D37968017

fbshipit-source-id: a3953fdbb2c48ceaffcf94df081c0b3253d247d5
parent 596f3721
...@@ -6,7 +6,8 @@ Tool for benchmarking data loading ...@@ -6,7 +6,8 @@ Tool for benchmarking data loading
import logging import logging
import time import time
from typing import Type, Union from dataclasses import dataclass
from typing import Any, Dict, Type, Union
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import numpy as np import numpy as np
...@@ -27,12 +28,19 @@ from fvcore.common.history_buffer import HistoryBuffer ...@@ -27,12 +28,19 @@ from fvcore.common.history_buffer import HistoryBuffer
logger = logging.getLogger("d2go.tools.benchmark_data") logger = logging.getLogger("d2go.tools.benchmark_data")
@dataclass
class BenchmarkDataOutput:
    """Structured result of the data-loading benchmark tool's ``main``.

    Replaces the previous plain-dict return value so callers get a typed,
    annotated output object.
    """

    # Evaluation results. In the visible caller, this is populated with the
    # same dict as ``metrics`` (both set from the local ``metrics`` variable).
    accuracy: Dict[str, Any]
    # TODO: support arbitrary levels of dicts
    # Nested metrics mapping; the visible caller builds it as
    # {"_name_": {dataset_name: results}}.
    metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]]
def main( def main(
cfg: CfgNode, cfg: CfgNode,
output_dir: str, output_dir: str,
runner_class: Union[str, Type[BaseRunner]], runner_class: Union[str, Type[BaseRunner]],
is_train: bool = True, is_train: bool = True,
): ) -> BenchmarkDataOutput:
runner = setup_after_launch(cfg, output_dir, runner_class) runner = setup_after_launch(cfg, output_dir, runner_class)
if is_train: if is_train:
...@@ -128,10 +136,10 @@ def main( ...@@ -128,10 +136,10 @@ def main(
metrics = {"_name_": {dataset_name: results}} metrics = {"_name_": {dataset_name: results}}
print_metrics_table(metrics) print_metrics_table(metrics)
return { return BenchmarkDataOutput(
"accuracy": metrics, accuracy=metrics,
"metrics": metrics, metrics=metrics,
} )
def run_with_cmdline_args(args): def run_with_cmdline_args(args):
......
...@@ -8,7 +8,8 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e ...@@ -8,7 +8,8 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
import logging import logging
import sys import sys
from typing import Optional, Type, Union from dataclasses import dataclass
from typing import Any, Dict, Optional, Type, Union
import torch import torch
from d2go.config import CfgNode from d2go.config import CfgNode
...@@ -28,6 +29,13 @@ from mobile_cv.predictor.api import create_predictor ...@@ -28,6 +29,13 @@ from mobile_cv.predictor.api import create_predictor
logger = logging.getLogger("d2go.tools.caffe2_evaluator") logger = logging.getLogger("d2go.tools.caffe2_evaluator")
@dataclass
class EvaluatorOutput:
    """Structured result of the caffe2/predictor evaluator tool's ``main``.

    Replaces the previous plain-dict return value so callers get a typed,
    annotated output object.
    """

    # Evaluation results from ``runner.do_test``. In the visible caller this
    # is populated with the same dict as ``metrics``.
    accuracy: Dict[str, Any]
    # TODO: support arbitrary levels of dicts
    # Nested metrics mapping returned by ``runner.do_test(cfg, predictor)``.
    metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]]
def main( def main(
cfg: CfgNode, cfg: CfgNode,
output_dir: str, output_dir: str,
...@@ -37,7 +45,7 @@ def main( ...@@ -37,7 +45,7 @@ def main(
num_threads: Optional[int] = None, num_threads: Optional[int] = None,
caffe2_engine: Optional[int] = None, caffe2_engine: Optional[int] = None,
caffe2_logging_print_net_summary: int = 0, caffe2_logging_print_net_summary: int = 0,
): ) -> EvaluatorOutput:
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
print("run with quantized engine: ", torch.backends.quantized.engine) print("run with quantized engine: ", torch.backends.quantized.engine)
...@@ -47,10 +55,10 @@ def main( ...@@ -47,10 +55,10 @@ def main(
predictor = create_predictor(predictor_path) predictor = create_predictor(predictor_path)
metrics = runner.do_test(cfg, predictor) metrics = runner.do_test(cfg, predictor)
print_metrics_table(metrics) print_metrics_table(metrics)
return { return EvaluatorOutput(
"accuracy": metrics, accuracy=metrics,
"metrics": metrics, metrics=metrics,
} )
@post_mortem_if_fail() @post_mortem_if_fail()
......
...@@ -9,7 +9,8 @@ deployable format (such as torchscript, caffe2, ...) ...@@ -9,7 +9,8 @@ deployable format (such as torchscript, caffe2, ...)
import copy import copy
import logging import logging
import sys import sys
from typing import Dict, List, Type, Union from dataclasses import dataclass
from typing import Any, Dict, List, Type, Union
import mobile_cv.lut.lib.pt.flops_utils as flops_utils import mobile_cv.lut.lib.pt.flops_utils as flops_utils
from d2go.config import CfgNode, temp_defrost from d2go.config import CfgNode, temp_defrost
...@@ -22,6 +23,12 @@ from mobile_cv.common.misc.py import post_mortem_if_fail ...@@ -22,6 +23,12 @@ from mobile_cv.common.misc.py import post_mortem_if_fail
logger = logging.getLogger("d2go.tools.export") logger = logging.getLogger("d2go.tools.export")
@dataclass
class ExporterOutput:
    """Structured result of the model-export tool's ``main``.

    Replaces the previous plain-dict return value so callers get a typed,
    annotated output object.
    """

    # Mapping of exported predictor identifiers to their on-disk paths.
    # NOTE(review): key/value semantics inferred from the name and the
    # Dict[str, str] annotation — confirm against the exporter body.
    predictor_paths: Dict[str, str]
    # Always empty in the visible caller: ``compare_accuracy`` currently
    # raises NotImplementedError, so no comparison is ever produced.
    accuracy_comparison: Dict[str, Any]
def main( def main(
cfg: CfgNode, cfg: CfgNode,
output_dir: str, output_dir: str,
...@@ -31,7 +38,7 @@ def main( ...@@ -31,7 +38,7 @@ def main(
device: str = "cpu", device: str = "cpu",
compare_accuracy: bool = False, compare_accuracy: bool = False,
skip_if_fail: bool = False, skip_if_fail: bool = False,
): ) -> ExporterOutput:
if compare_accuracy: if compare_accuracy:
raise NotImplementedError( raise NotImplementedError(
"compare_accuracy functionality isn't currently supported." "compare_accuracy functionality isn't currently supported."
...@@ -76,9 +83,10 @@ def main( ...@@ -76,9 +83,10 @@ def main(
if not skip_if_fail: if not skip_if_fail:
raise e raise e
ret = {"predictor_paths": predictor_paths, "accuracy_comparison": {}} return ExporterOutput(
predictor_paths=predictor_paths,
return ret accuracy_comparison={},
)
@post_mortem_if_fail() @post_mortem_if_fail()
......
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
import logging import logging
import os import os
from dataclasses import dataclass from typing import Any, Dict, List, Type, Union
from typing import Any, Dict, List, Optional, Type, Union
import mobile_cv.torch.utils_pytorch.comm as comm import mobile_cv.torch.utils_pytorch.comm as comm
import pytorch_lightning as pl # type: ignore import pytorch_lightning as pl # type: ignore
...@@ -13,6 +12,7 @@ from d2go.config import CfgNode ...@@ -13,6 +12,7 @@ from d2go.config import CfgNode
from d2go.runner.callbacks.quantization import QuantizationAwareTraining from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import DefaultTask from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from d2go.tools.train_net import TrainNetOutput
from d2go.trainer.lightning.training_loop import _do_test, _do_train from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar from pytorch_lightning.callbacks import Callback, LearningRateMonitor, TQDMProgressBar
...@@ -28,14 +28,6 @@ logger = logging.getLogger("detectron2go.lightning.train_net") ...@@ -28,14 +28,6 @@ logger = logging.getLogger("detectron2go.lightning.train_net")
FINAL_MODEL_CKPT = f"model_final{ModelCheckpoint.FILE_EXTENSION}" FINAL_MODEL_CKPT = f"model_final{ModelCheckpoint.FILE_EXTENSION}"
@dataclass
class TrainOutput:
output_dir: str
accuracy: Optional[Dict[str, Any]] = None
tensorboard_log_dir: Optional[str] = None
model_configs: Optional[Dict[str, str]] = None
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]: def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
"""Gets the trainer callbacks based on the given D2Go Config. """Gets the trainer callbacks based on the given D2Go Config.
...@@ -93,7 +85,7 @@ def main( ...@@ -93,7 +85,7 @@ def main(
output_dir: str, output_dir: str,
runner_class: Union[str, Type[DefaultTask]], runner_class: Union[str, Type[DefaultTask]],
eval_only: bool = False, eval_only: bool = False,
) -> TrainOutput: ) -> TrainNetOutput:
"""Main function for launching a training with lightning trainer """Main function for launching a training with lightning trainer
Args: Args:
cfg: D2go config node cfg: D2go config node
...@@ -119,10 +111,10 @@ def main( ...@@ -119,10 +111,10 @@ def main(
else: else:
model_configs = _do_train(cfg, trainer, task) model_configs = _do_train(cfg, trainer, task)
return TrainOutput( return TrainNetOutput(
output_dir=cfg.OUTPUT_DIR,
tensorboard_log_dir=trainer_params["logger"].log_dir, tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res, accuracy=task.eval_res,
metrics=task.eval_res,
model_configs=model_configs, model_configs=model_configs,
) )
......
...@@ -7,7 +7,8 @@ Detection Training Script. ...@@ -7,7 +7,8 @@ Detection Training Script.
import logging import logging
import sys import sys
from typing import List, Type, Union from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, Union
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
from d2go.config import CfgNode from d2go.config import CfgNode
...@@ -31,13 +32,23 @@ from detectron2.engine.defaults import create_ddp_model ...@@ -31,13 +32,23 @@ from detectron2.engine.defaults import create_ddp_model
logger = logging.getLogger("d2go.tools.train_net") logger = logging.getLogger("d2go.tools.train_net")
@dataclass
class TrainNetOutput:
    """Structured result of the train_net tool's ``main``.

    Replaces the previous plain-dict return value; also reused by the
    lightning train_net entry point (which supplies ``tensorboard_log_dir``).
    """

    # Evaluation results; in the visible callers this receives the same dict
    # as ``metrics`` (commented "for e2e_workflow" at one call site).
    accuracy: Dict[str, Any]
    # TODO: support arbitrary levels of dicts
    # Nested metrics mapping from ``runner.do_test``.
    metrics: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
    # Dumped trained-model config paths ("for unit_workflow"); empty dict in
    # the eval-only path.
    model_configs: Dict[str, str]
    # TODO: decide if `tensorboard_log_dir` should be part of output
    # Only set by the lightning entry point (from the trainer's logger).
    tensorboard_log_dir: Optional[str] = None
def main( def main(
cfg: CfgNode, cfg: CfgNode,
output_dir: str, output_dir: str,
runner_class: Union[str, Type[BaseRunner]], runner_class: Union[str, Type[BaseRunner]],
eval_only: bool = False, eval_only: bool = False,
resume: bool = True, # NOTE: always enable resume when running on cluster resume: bool = True, # NOTE: always enable resume when running on cluster
): ) -> TrainNetOutput:
runner = setup_after_launch(cfg, output_dir, runner_class) runner = setup_after_launch(cfg, output_dir, runner_class)
model = runner.build_model(cfg) model = runner.build_model(cfg)
...@@ -55,11 +66,11 @@ def main( ...@@ -55,11 +66,11 @@ def main(
model.eval() model.eval()
metrics = runner.do_test(cfg, model, train_iter=train_iter) metrics = runner.do_test(cfg, model, train_iter=train_iter)
print_metrics_table(metrics) print_metrics_table(metrics)
return { return TrainNetOutput(
"accuracy": metrics, accuracy=metrics,
"model_configs": {}, model_configs={},
"metrics": metrics, metrics=metrics,
} )
model = create_ddp_model( model = create_ddp_model(
model, model,
...@@ -75,13 +86,13 @@ def main( ...@@ -75,13 +86,13 @@ def main(
# dump config files for trained models # dump config files for trained models
trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs) trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs)
return { return TrainNetOutput(
# for e2e_workflow # for e2e_workflow
"accuracy": metrics, accuracy=metrics,
# for unit_workflow # for unit_workflow
"model_configs": trained_model_configs, model_configs=trained_model_configs,
"metrics": metrics, metrics=metrics,
} )
def run_with_cmdline_args(args): def run_with_cmdline_args(args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment