Commit 8051775c authored by Yanghan Wang, committed by Facebook GitHub Bot

use runner class instead of instance outside of main

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/312

As discussed, we decided not to use a runner instance outside of `main`. Previous diffs already solved the prerequisites; this diff mainly does the renaming.
- Use the runner name (str) in the FBLearner ML pipeline.
- Use the runner name (str) in the FBL operator, MAST, and the binary operator.
- Use the runner class as the interface of `main`; it can be either the name of the class (str) or the actual class. The primary usage should be the `str` form, so that importing the class happens inside `main`. But importing a runner class and calling `main` directly is also a common use case (e.g. for ad-hoc scripts or tests), and supporting the actual class makes those cases easier to write (e.g. a local test class may not have an importable name, so a runner name is not feasible). Both call styles are sketched below.
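A minimal sketch of the two call styles, assuming the tool-style `main(cfg, output_dir, runner_class, ...)` signature introduced in this diff (the output directory is hypothetical):

```python
from d2go.runner import GeneralizedRCNNRunner
from d2go.tools.train_net import main  # any tool-style main(cfg, output_dir, runner_class)

cfg = GeneralizedRCNNRunner.get_default_cfg()
output_dir = "/tmp/d2go_output"  # hypothetical

# Preferred: pass the runner name (str); the class is imported inside main().
main(cfg, output_dir, "d2go.runner.GeneralizedRCNNRunner")

# Also supported: pass the class itself, e.g. for ad-hoc scripts or tests
# where a local test class has no importable name.
main(cfg, output_dir, GeneralizedRCNNRunner)
```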

Reviewed By: newstzpz

Differential Revision: D37060338

fbshipit-source-id: 879852d41902b87d6db6cb9d7b3e8dc55dc4b976
parent b077a2c1
@@ -18,6 +18,7 @@ __all__ = [
     "GeneralizedRCNNRunner",
     "TRAINER_HOOKS_REGISTRY",
     "create_runner",
+    "import_runner",
 ]
...
@@ -6,7 +6,7 @@ import argparse
 import logging
 import os
 import time
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Tuple, Type, Union

 import detectron2.utils.comm as comm
 import torch
@@ -19,13 +19,7 @@ from d2go.config import (
 )
 from d2go.config.utils import get_diff_cfg
 from d2go.distributed import get_local_rank, get_num_processes_per_machine
-from d2go.runner import (
-    BaseRunner,
-    create_runner,
-    DefaultTask,
-    import_runner,
-    RunnerV2Mixin,
-)
+from d2go.runner import BaseRunner, DefaultTask, import_runner, RunnerV2Mixin
 from d2go.utils.helper import run_once
 from d2go.utils.launch_environment import get_launch_environment
 from detectron2.utils.collect_env import collect_env_info
@@ -149,6 +143,7 @@ def create_cfg_from_cli(
         print("Loaded config file {}:\n{}".format(config_file, f.read()))
     if isinstance(runner_class, str):
+        print(f"Importing runner: {runner_class} ...")
         runner_class = import_runner(runner_class)
     if runner_class is None or issubclass(runner_class, RunnerV2Mixin):
         # Runner-less API
@@ -164,7 +159,9 @@ def create_cfg_from_cli(
     return cfg


-def prepare_for_launch(args):
+def prepare_for_launch(
+    args,
+) -> Tuple[CfgNode, str, Optional[str]]:
     """
     Load config, figure out working directory, create runner.
     - when args.config_file is empty, returned cfg will be the default one
@@ -183,9 +180,7 @@ def prepare_for_launch(args):
     assert args.output_dir or args.config_file
     output_dir = args.output_dir or cfg.OUTPUT_DIR

-    # TODO (T123980149): use runner_name across the board
-    runner = create_runner(args.runner)
-    return cfg, output_dir, runner
+    return cfg, output_dir, args.runner


 def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
@@ -202,8 +197,8 @@ def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
 def setup_after_launch(
     cfg: CfgNode,
     output_dir: str,
-    runner: Union[BaseRunner, Type[DefaultTask], None],
-):
+    runner_class: Union[None, str, Type[BaseRunner], Type[DefaultTask]],
+) -> Union[None, BaseRunner, Type[DefaultTask]]:
     """
     Binary-level setup after entering DDP, including
     - creating working directory
@@ -222,10 +217,22 @@ def setup_after_launch(
     logger.info("Running with full config:\n{}".format(cfg))
     dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))

-    if isinstance(runner, BaseRunner):
-        logger.info("Initializing runner ...")
+    if isinstance(runner_class, str):
+        logger.info(f"Importing runner: {runner_class} ...")
+        runner_class = import_runner(runner_class)
+
+    if issubclass(runner_class, DefaultTask):
+        # TODO(T123679504): merge this with runner code path to return runner instance
+        logger.info(f"Importing lightning task: {runner_class} ...")
+        runner = runner_class
+    elif issubclass(runner_class, BaseRunner):
+        logger.info(f"Initializing runner: {runner_class} ...")
+        runner = runner_class()
         runner = initialize_runner(runner, cfg)
         logger.info("Running with runner: {}".format(runner))
+    else:
+        assert runner_class is None, f"Unsupported runner class: {runner_class}"
+        runner = None

     # save the diff config
     default_cfg = (
@@ -241,6 +248,8 @@ def setup_after_launch(
     # scale the config after dumping so that dumped config files keep original world size
     auto_scale_world_size(cfg, new_world_size=comm.get_world_size())

+    return runner
+

 def setup_logger(
     module_name: str,
...
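For reference, `import_runner` (re-exported above) resolves a dotted runner name to its class. Its implementation is not part of this diff; a minimal sketch of the idea, using a hypothetical helper name, could look like:

```python
import importlib

def import_runner_sketch(runner_name: str) -> type:
    # Split "pkg.module.ClassName" into module path and class name,
    # import the module, and return the class attribute.
    module_name, _, class_name = runner_name.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# e.g. import_runner_sketch("d2go.runner.GeneralizedRCNNRunner")
```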
@@ -33,7 +33,7 @@ class TestLightningTrainNet(unittest.TestCase):
         cfg = self._get_cfg(root_dir)
         # set distributed backend to none to avoid spawning child process,
         # which doesn't inherit the temporary dataset
-        main(cfg, root_dir)
+        main(cfg, root_dir, GeneralizedRCNNTask)

     @tempdir
     @enable_ddp_env
@@ -41,7 +41,7 @@ class TestLightningTrainNet(unittest.TestCase):
         """tests saving and loading from checkpoint."""
         cfg = self._get_cfg(tmp_dir)

-        out = main(cfg, tmp_dir)
+        out = main(cfg, tmp_dir, GeneralizedRCNNTask)
         ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
         expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
         for ckpt in expected_ckpts:
@@ -53,7 +53,7 @@ class TestLightningTrainNet(unittest.TestCase):
         cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
         output_dir = os.path.join(tmp_dir, "output")
-        out2 = main(cfg2, output_dir, eval_only=True)
+        out2 = main(cfg2, output_dir, GeneralizedRCNNTask, eval_only=True)

         accuracy = flatten_config_dict(out.accuracy)
         accuracy2 = flatten_config_dict(out2.accuracy)
         for k in accuracy:
...
@@ -24,8 +24,7 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
         (dataset_name,),
     ]
     # START_WIKI_EXAMPLE_TAG
-    runner = GeneralizedRCNNRunner()
-    cfg = runner.get_default_cfg()
+    cfg = GeneralizedRCNNRunner.get_default_cfg()
     cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
     cfg.merge_from_list(get_quick_test_config_opts())
     cfg.merge_from_list(config_list)
@@ -35,7 +34,7 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
     _ = main(
         cfg,
         tmp_dir,
-        runner,
+        GeneralizedRCNNRunner,
         predictor_types=["torchscript@c2_ops", "torchscript"],
     )
...
@@ -6,10 +6,13 @@ Tool for benchmarking data loading
 import logging
 import time
+from typing import Type, Union

 import detectron2.utils.comm as comm
 import numpy as np
+from d2go.config import CfgNode
 from d2go.distributed import get_num_processes_per_machine, launch
+from d2go.runner import BaseRunner
 from d2go.setup import (
     basic_argument_parser,
     post_mortem_if_fail_for_main,
@@ -21,17 +24,16 @@ from detectron2.fb.env import get_launch_environment
 from detectron2.utils.logger import log_every_n_seconds
 from fvcore.common.history_buffer import HistoryBuffer

 logger = logging.getLogger("d2go.tools.benchmark_data")


 def main(
-    cfg,
-    output_dir,
-    runner=None,
-    is_train=True,
+    cfg: CfgNode,
+    output_dir: str,
+    runner_class: Union[str, Type[BaseRunner]],
+    is_train: bool = True,
 ):
-    setup_after_launch(cfg, output_dir, runner)
+    runner = setup_after_launch(cfg, output_dir, runner_class)

     if is_train:
         data_loader = runner.build_detection_train_loader(cfg)
@@ -133,7 +135,7 @@ def main(

 def run_with_cmdline_args(args):
-    cfg, output_dir, runner = prepare_for_launch(args)
+    cfg, output_dir, runner_name = prepare_for_launch(args)
     launch(
         post_mortem_if_fail_for_main(main),
         num_processes_per_machine=args.num_processes,
@@ -141,7 +143,7 @@ def run_with_cmdline_args(args):
         machine_rank=args.machine_rank,
         dist_url=args.dist_url,
         backend=args.dist_backend,
-        args=(cfg, output_dir, runner, args.is_train),
+        args=(cfg, output_dir, runner_name, args.is_train),
     )
...
@@ -8,9 +8,12 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
 import logging
 import sys
+from typing import Optional, Type, Union

 import torch
+from d2go.config import CfgNode
 from d2go.distributed import launch
+from d2go.runner import BaseRunner
 from d2go.setup import (
     basic_argument_parser,
     caffe2_global_init,
@@ -26,19 +29,19 @@ logger = logging.getLogger("d2go.tools.caffe2_evaluator")

 def main(
-    cfg,
-    output_dir,
-    runner,
+    cfg: CfgNode,
+    output_dir: str,
+    runner_class: Union[str, Type[BaseRunner]],
     # binary specific optional arguments
-    predictor_path,
-    num_threads=None,
-    caffe2_engine=None,
-    caffe2_logging_print_net_summary=0,
+    predictor_path: str,
+    num_threads: Optional[int] = None,
+    caffe2_engine: Optional[int] = None,
+    caffe2_logging_print_net_summary: int = 0,
 ):
     torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
     print("run with quantized engine: ", torch.backends.quantized.engine)

-    setup_after_launch(cfg, output_dir, runner)
+    runner = setup_after_launch(cfg, output_dir, runner_class)
     caffe2_global_init(caffe2_logging_print_net_summary, num_threads)

     predictor = create_predictor(predictor_path)
@@ -52,7 +55,7 @@ def main(

 @post_mortem_if_fail()
 def run_with_cmdline_args(args):
-    cfg, output_dir, runner = prepare_for_launch(args)
+    cfg, output_dir, runner_name = prepare_for_launch(args)
     launch(
         post_mortem_if_fail_for_main(main),
         args.num_processes,
@@ -64,7 +67,7 @@ def run_with_cmdline_args(args):
         args=(
             cfg,
             output_dir,
-            runner,
+            runner_name,
             # binary specific optional arguments
             args.predictor_path,
             args.num_threads,
...
@@ -9,11 +9,12 @@ deployable format (such as torchscript, caffe2, ...)
 import copy
 import logging
 import sys
-import typing
+from typing import Dict, List, Type, Union

 import mobile_cv.lut.lib.pt.flops_utils as flops_utils
-from d2go.config import temp_defrost
+from d2go.config import CfgNode, temp_defrost
 from d2go.export.exporter import convert_and_export_predictor
+from d2go.runner import BaseRunner
 from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
 from mobile_cv.common.misc.py import post_mortem_if_fail
@@ -22,11 +23,11 @@ logger = logging.getLogger("d2go.tools.export")

 def main(
-    cfg,
-    output_dir,
-    runner,
+    cfg: CfgNode,
+    output_dir: str,
+    runner_class: Union[str, Type[BaseRunner]],
     # binary specific optional arguments
-    predictor_types: typing.List[str],
+    predictor_types: List[str],
     device: str = "cpu",
     compare_accuracy: bool = False,
     skip_if_fail: bool = False,
@@ -39,7 +40,7 @@ def main(
     #   ret["accuracy_comparison"] = accuracy_comparison

     cfg = copy.deepcopy(cfg)
-    setup_after_launch(cfg, output_dir, runner)
+    runner = setup_after_launch(cfg, output_dir, runner_class)

     with temp_defrost(cfg):
         cfg.merge_from_list(["MODEL.DEVICE", device])
@@ -56,7 +57,7 @@ def main(
         input_args = (first_batch,)
         flops_utils.print_model_flops(model, input_args)

-    predictor_paths: typing.Dict[str, str] = {}
+    predictor_paths: Dict[str, str] = {}
     for typ in predictor_types:
         # convert_and_export_predictor might alter the model, copy before calling it
         pytorch_model = copy.deepcopy(model)
@@ -82,11 +83,11 @@ def main(

 @post_mortem_if_fail()
 def run_with_cmdline_args(args):
-    cfg, output_dir, runner = prepare_for_launch(args)
+    cfg, output_dir, runner_name = prepare_for_launch(args)
     return main(
         cfg,
         output_dir,
-        runner,
+        runner_name,
         # binary specific optional arguments
         predictor_types=args.predictor_types,
         device=args.device,
...
@@ -5,14 +5,14 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Optional, Type, Union

 import mobile_cv.torch.utils_pytorch.comm as comm
 import pytorch_lightning as pl  # type: ignore
 from d2go.config import CfgNode
-from d2go.runner import create_runner
 from d2go.runner.callbacks.quantization import QuantizationAwareTraining
-from d2go.runner.lightning_task import GeneralizedRCNNTask
+from d2go.runner.lightning_task import DefaultTask, GeneralizedRCNNTask
 from d2go.setup import basic_argument_parser, setup_after_launch
 from d2go.trainer.lightning.training_loop import _do_test, _do_train
 from detectron2.utils.file_io import PathManager
@@ -92,7 +92,7 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
 def main(
     cfg: CfgNode,
     output_dir: str,
-    task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
+    runner_class: Union[str, Type[DefaultTask]],
     eval_only: bool = False,
 ) -> TrainOutput:
     """Main function for launching a training with lightning trainer
@@ -102,7 +102,7 @@ def main(
         num_processes: Number of processes on each node.
         eval_only: True if run evaluation only.
     """
-    setup_after_launch(cfg, output_dir, task_cls)
+    task_cls: Type[DefaultTask] = setup_after_launch(cfg, output_dir, runner_class)

     task = task_cls.from_config(cfg, eval_only)
     trainer_params = get_trainer_params(cfg)
@@ -130,7 +130,7 @@ def main(

 def build_config(
     config_file: str,
-    task_cls: Type[GeneralizedRCNNTask],
+    task_cls: Type[DefaultTask],
     opts: Optional[List[str]] = None,
 ) -> CfgNode:
     """Build config node from config file
...
@@ -7,10 +7,12 @@ Detection Training Script.
 import logging
 import sys
-from typing import List
+from typing import List, Type, Union

 import detectron2.utils.comm as comm
+from d2go.config import CfgNode
 from d2go.distributed import launch
+from d2go.runner import BaseRunner
 from d2go.setup import (
     basic_argument_parser,
     build_basic_cli_args,
@@ -30,14 +32,13 @@ logger = logging.getLogger("d2go.tools.train_net")

 def main(
-    cfg,
-    output_dir,
-    runner=None,
-    eval_only=False,
-    # NOTE: always enable resume when running on cluster
-    resume=True,
+    cfg: CfgNode,
+    output_dir: str,
+    runner_class: Union[str, Type[BaseRunner]],
+    eval_only: bool = False,
+    resume: bool = True,  # NOTE: always enable resume when running on cluster
 ):
-    setup_after_launch(cfg, output_dir, runner)
+    runner = setup_after_launch(cfg, output_dir, runner_class)

     model = runner.build_model(cfg)
     logger.info("Model:\n{}".format(model))
@@ -84,7 +85,7 @@ def main(

 def run_with_cmdline_args(args):
-    cfg, output_dir, runner = prepare_for_launch(args)
+    cfg, output_dir, runner_name = prepare_for_launch(args)
     outputs = launch(
         post_mortem_if_fail_for_main(main),
@@ -93,7 +94,7 @@ def run_with_cmdline_args(args):
         machine_rank=args.machine_rank,
         dist_url=args.dist_url,
         backend=args.dist_backend,
-        args=(cfg, output_dir, runner, args.eval_only, args.resume),
+        args=(cfg, output_dir, runner_name, args.eval_only, args.resume),
     )
     if args.save_return_file is not None:
...
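Taken together, the binaries now carry the runner name as a plain string from argument parsing through `launch`, and only `setup_after_launch` turns it into a class or instance after entering DDP. A condensed sketch of that flow, with a hypothetical stand-in for the parsed CLI args (field names follow their usage in this diff; `basic_argument_parser` may read more fields, e.g. `opts`):

```python
from argparse import Namespace
from d2go.setup import prepare_for_launch, setup_after_launch

# Hypothetical stand-in for the args produced by basic_argument_parser.
args = Namespace(
    config_file="",  # empty config_file falls back to the default cfg
    output_dir="/tmp/d2go_output",
    runner="d2go.runner.GeneralizedRCNNRunner",
    opts=[],
)

cfg, output_dir, runner_name = prepare_for_launch(args)   # runner stays a str here
runner = setup_after_launch(cfg, output_dir, runner_name)  # imported & initialized after launch
```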