Commit 8051775c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

use runner class instead of instance outside of main

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/312

As discussed, we decided to not use runner instance outside of `main`, previous diffs already solved the prerequisites, this diff mainly does the renaming.
- Use runner name (str) in the fblearner, ML pipeline.
- Use runner name (str) in FBL operator, MAST and binary operator.
- Use runner class as the interface of main, it can be either the name of the class (str) or the actual class. The main usage should be using `str`, so that the importing of the class happens inside `main`. But it's also a common use case to import the runner class and call `main` for things like ad-hoc scripts or tests; supporting the actual class makes it easier to modify code for those cases (e.g. some local test class doesn't have a name, so it's not feasible to use the runner name).

Reviewed By: newstzpz

Differential Revision: D37060338

fbshipit-source-id: 879852d41902b87d6db6cb9d7b3e8dc55dc4b976
parent b077a2c1
......@@ -18,6 +18,7 @@ __all__ = [
"GeneralizedRCNNRunner",
"TRAINER_HOOKS_REGISTRY",
"create_runner",
"import_runner",
]
......
......@@ -6,7 +6,7 @@ import argparse
import logging
import os
import time
from typing import List, Optional, Type, Union
from typing import List, Optional, Tuple, Type, Union
import detectron2.utils.comm as comm
import torch
......@@ -19,13 +19,7 @@ from d2go.config import (
)
from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import (
BaseRunner,
create_runner,
DefaultTask,
import_runner,
RunnerV2Mixin,
)
from d2go.runner import BaseRunner, DefaultTask, import_runner, RunnerV2Mixin
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info
......@@ -149,6 +143,7 @@ def create_cfg_from_cli(
print("Loaded config file {}:\n{}".format(config_file, f.read()))
if isinstance(runner_class, str):
print(f"Importing runner: {runner_class} ...")
runner_class = import_runner(runner_class)
if runner_class is None or issubclass(runner_class, RunnerV2Mixin):
# Runner-less API
......@@ -164,7 +159,9 @@ def create_cfg_from_cli(
return cfg
def prepare_for_launch(args):
def prepare_for_launch(
args,
) -> Tuple[CfgNode, str, Optional[str]]:
"""
Load config, figure out working directory, create runner.
- when args.config_file is empty, returned cfg will be the default one
......@@ -183,9 +180,7 @@ def prepare_for_launch(args):
assert args.output_dir or args.config_file
output_dir = args.output_dir or cfg.OUTPUT_DIR
# TODO (T123980149): use runner_name across the board
runner = create_runner(args.runner)
return cfg, output_dir, runner
return cfg, output_dir, args.runner
def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
......@@ -202,8 +197,8 @@ def maybe_override_output_dir(cfg: CfgNode, output_dir: str):
def setup_after_launch(
cfg: CfgNode,
output_dir: str,
runner: Union[BaseRunner, Type[DefaultTask], None],
):
runner_class: Union[None, str, Type[BaseRunner], Type[DefaultTask]],
) -> Union[None, BaseRunner, Type[DefaultTask]]:
"""
Binary-level setup after entering DDP, including
- creating working directory
......@@ -222,10 +217,22 @@ def setup_after_launch(
logger.info("Running with full config:\n{}".format(cfg))
dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
if isinstance(runner, BaseRunner):
logger.info("Initializing runner ...")
if isinstance(runner_class, str):
logger.info(f"Importing runner: {runner_class} ...")
runner_class = import_runner(runner_class)
if issubclass(runner_class, DefaultTask):
# TODO(T123679504): merge this with runner code path to return runner instance
logger.info(f"Importing lightning task: {runner_class} ...")
runner = runner_class
elif issubclass(runner_class, BaseRunner):
logger.info(f"Initializing runner: {runner_class} ...")
runner = runner_class()
runner = initialize_runner(runner, cfg)
logger.info("Running with runner: {}".format(runner))
else:
assert runner_class is None, f"Unsupported runner class: {runner_class}"
runner = None
# save the diff config
default_cfg = (
......@@ -241,6 +248,8 @@ def setup_after_launch(
# scale the config after dumping so that dumped config files keep original world size
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
return runner
def setup_logger(
module_name: str,
......
......@@ -33,7 +33,7 @@ class TestLightningTrainNet(unittest.TestCase):
cfg = self._get_cfg(root_dir)
# set distributed backend to none to avoid spawning child process,
# which doesn't inherit the temporary dataset
main(cfg, root_dir)
main(cfg, root_dir, GeneralizedRCNNTask)
@tempdir
@enable_ddp_env
......@@ -41,7 +41,7 @@ class TestLightningTrainNet(unittest.TestCase):
"""tests saving and loading from checkpoint."""
cfg = self._get_cfg(tmp_dir)
out = main(cfg, tmp_dir)
out = main(cfg, tmp_dir, GeneralizedRCNNTask)
ckpts = [f for f in os.listdir(tmp_dir) if f.endswith(".ckpt")]
expected_ckpts = ("last.ckpt", FINAL_MODEL_CKPT)
for ckpt in expected_ckpts:
......@@ -53,7 +53,7 @@ class TestLightningTrainNet(unittest.TestCase):
cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
output_dir = os.path.join(tmp_dir, "output")
out2 = main(cfg2, output_dir, eval_only=True)
out2 = main(cfg2, output_dir, GeneralizedRCNNTask, eval_only=True)
accuracy = flatten_config_dict(out.accuracy)
accuracy2 = flatten_config_dict(out2.accuracy)
for k in accuracy:
......
......@@ -24,8 +24,7 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
(dataset_name,),
]
# START_WIKI_EXAMPLE_TAG
runner = GeneralizedRCNNRunner()
cfg = runner.get_default_cfg()
cfg = GeneralizedRCNNRunner.get_default_cfg()
cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
cfg.merge_from_list(get_quick_test_config_opts())
cfg.merge_from_list(config_list)
......@@ -35,7 +34,7 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
_ = main(
cfg,
tmp_dir,
runner,
GeneralizedRCNNRunner,
predictor_types=["torchscript@c2_ops", "torchscript"],
)
......
......@@ -6,10 +6,13 @@ Tool for benchmarking data loading
import logging
import time
from typing import Type, Union
import detectron2.utils.comm as comm
import numpy as np
from d2go.config import CfgNode
from d2go.distributed import get_num_processes_per_machine, launch
from d2go.runner import BaseRunner
from d2go.setup import (
basic_argument_parser,
post_mortem_if_fail_for_main,
......@@ -21,17 +24,16 @@ from detectron2.fb.env import get_launch_environment
from detectron2.utils.logger import log_every_n_seconds
from fvcore.common.history_buffer import HistoryBuffer
logger = logging.getLogger("d2go.tools.benchmark_data")
def main(
cfg,
output_dir,
runner=None,
is_train=True,
cfg: CfgNode,
output_dir: str,
runner_class: Union[str, Type[BaseRunner]],
is_train: bool = True,
):
setup_after_launch(cfg, output_dir, runner)
runner = setup_after_launch(cfg, output_dir, runner_class)
if is_train:
data_loader = runner.build_detection_train_loader(cfg)
......@@ -133,7 +135,7 @@ def main(
def run_with_cmdline_args(args):
cfg, output_dir, runner = prepare_for_launch(args)
cfg, output_dir, runner_name = prepare_for_launch(args)
launch(
post_mortem_if_fail_for_main(main),
num_processes_per_machine=args.num_processes,
......@@ -141,7 +143,7 @@ def run_with_cmdline_args(args):
machine_rank=args.machine_rank,
dist_url=args.dist_url,
backend=args.dist_backend,
args=(cfg, output_dir, runner, args.is_train),
args=(cfg, output_dir, runner_name, args.is_train),
)
......
......@@ -8,9 +8,12 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
import logging
import sys
from typing import Optional, Type, Union
import torch
from d2go.config import CfgNode
from d2go.distributed import launch
from d2go.runner import BaseRunner
from d2go.setup import (
basic_argument_parser,
caffe2_global_init,
......@@ -26,19 +29,19 @@ logger = logging.getLogger("d2go.tools.caffe2_evaluator")
def main(
cfg,
output_dir,
runner,
cfg: CfgNode,
output_dir: str,
runner_class: Union[str, Type[BaseRunner]],
# binary specific optional arguments
predictor_path,
num_threads=None,
caffe2_engine=None,
caffe2_logging_print_net_summary=0,
predictor_path: str,
num_threads: Optional[int] = None,
caffe2_engine: Optional[int] = None,
caffe2_logging_print_net_summary: int = 0,
):
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
print("run with quantized engine: ", torch.backends.quantized.engine)
setup_after_launch(cfg, output_dir, runner)
runner = setup_after_launch(cfg, output_dir, runner_class)
caffe2_global_init(caffe2_logging_print_net_summary, num_threads)
predictor = create_predictor(predictor_path)
......@@ -52,7 +55,7 @@ def main(
@post_mortem_if_fail()
def run_with_cmdline_args(args):
cfg, output_dir, runner = prepare_for_launch(args)
cfg, output_dir, runner_name = prepare_for_launch(args)
launch(
post_mortem_if_fail_for_main(main),
args.num_processes,
......@@ -64,7 +67,7 @@ def run_with_cmdline_args(args):
args=(
cfg,
output_dir,
runner,
runner_name,
# binary specific optional arguments
args.predictor_path,
args.num_threads,
......
......@@ -9,11 +9,12 @@ deployable format (such as torchscript, caffe2, ...)
import copy
import logging
import sys
import typing
from typing import Dict, List, Type, Union
import mobile_cv.lut.lib.pt.flops_utils as flops_utils
from d2go.config import temp_defrost
from d2go.config import CfgNode, temp_defrost
from d2go.export.exporter import convert_and_export_predictor
from d2go.runner import BaseRunner
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from mobile_cv.common.misc.py import post_mortem_if_fail
......@@ -22,11 +23,11 @@ logger = logging.getLogger("d2go.tools.export")
def main(
cfg,
output_dir,
runner,
cfg: CfgNode,
output_dir: str,
runner_class: Union[str, Type[BaseRunner]],
# binary specific optional arguments
predictor_types: typing.List[str],
predictor_types: List[str],
device: str = "cpu",
compare_accuracy: bool = False,
skip_if_fail: bool = False,
......@@ -39,7 +40,7 @@ def main(
# ret["accuracy_comparison"] = accuracy_comparison
cfg = copy.deepcopy(cfg)
setup_after_launch(cfg, output_dir, runner)
runner = setup_after_launch(cfg, output_dir, runner_class)
with temp_defrost(cfg):
cfg.merge_from_list(["MODEL.DEVICE", device])
......@@ -56,7 +57,7 @@ def main(
input_args = (first_batch,)
flops_utils.print_model_flops(model, input_args)
predictor_paths: typing.Dict[str, str] = {}
predictor_paths: Dict[str, str] = {}
for typ in predictor_types:
# convert_and_export_predictor might alter the model, copy before calling it
pytorch_model = copy.deepcopy(model)
......@@ -82,11 +83,11 @@ def main(
@post_mortem_if_fail()
def run_with_cmdline_args(args):
cfg, output_dir, runner = prepare_for_launch(args)
cfg, output_dir, runner_name = prepare_for_launch(args)
return main(
cfg,
output_dir,
runner,
runner_name,
# binary specific optional arguments
predictor_types=args.predictor_types,
device=args.device,
......
......@@ -5,14 +5,14 @@
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union
import mobile_cv.torch.utils_pytorch.comm as comm
import pytorch_lightning as pl # type: ignore
from d2go.config import CfgNode
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.runner.lightning_task import DefaultTask, GeneralizedRCNNTask
from d2go.setup import basic_argument_parser, setup_after_launch
from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager
......@@ -92,7 +92,7 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
def main(
cfg: CfgNode,
output_dir: str,
task_cls: Type[GeneralizedRCNNTask] = GeneralizedRCNNTask,
runner_class: Union[str, Type[DefaultTask]],
eval_only: bool = False,
) -> TrainOutput:
"""Main function for launching a training with lightning trainer
......@@ -102,7 +102,7 @@ def main(
num_processes: Number of processes on each node.
eval_only: True if run evaluation only.
"""
setup_after_launch(cfg, output_dir, task_cls)
task_cls: Type[DefaultTask] = setup_after_launch(cfg, output_dir, runner_class)
task = task_cls.from_config(cfg, eval_only)
trainer_params = get_trainer_params(cfg)
......@@ -130,7 +130,7 @@ def main(
def build_config(
config_file: str,
task_cls: Type[GeneralizedRCNNTask],
task_cls: Type[DefaultTask],
opts: Optional[List[str]] = None,
) -> CfgNode:
"""Build config node from config file
......
......@@ -7,10 +7,12 @@ Detection Training Script.
import logging
import sys
from typing import List
from typing import List, Type, Union
import detectron2.utils.comm as comm
from d2go.config import CfgNode
from d2go.distributed import launch
from d2go.runner import BaseRunner
from d2go.setup import (
basic_argument_parser,
build_basic_cli_args,
......@@ -30,14 +32,13 @@ logger = logging.getLogger("d2go.tools.train_net")
def main(
cfg,
output_dir,
runner=None,
eval_only=False,
# NOTE: always enable resume when running on cluster
resume=True,
cfg: CfgNode,
output_dir: str,
runner_class: Union[str, Type[BaseRunner]],
eval_only: bool = False,
resume: bool = True, # NOTE: always enable resume when running on cluster
):
setup_after_launch(cfg, output_dir, runner)
runner = setup_after_launch(cfg, output_dir, runner_class)
model = runner.build_model(cfg)
logger.info("Model:\n{}".format(model))
......@@ -84,7 +85,7 @@ def main(
def run_with_cmdline_args(args):
cfg, output_dir, runner = prepare_for_launch(args)
cfg, output_dir, runner_name = prepare_for_launch(args)
outputs = launch(
post_mortem_if_fail_for_main(main),
......@@ -93,7 +94,7 @@ def run_with_cmdline_args(args):
machine_rank=args.machine_rank,
dist_url=args.dist_url,
backend=args.dist_backend,
args=(cfg, output_dir, runner, args.eval_only, args.resume),
args=(cfg, output_dir, runner_name, args.eval_only, args.resume),
)
if args.save_return_file is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment