"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0703ce88008b2765ef6636c6e5cb013d227c42ca"
Commit b077a2c1 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

create cfg using runner class instead of instance

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/310

One step towards using runner class across board. This is also need for supporting LazyConfig (which is runner-less)

Reviewed By: mcimpoi

Differential Revision: D37294926

fbshipit-source-id: f6dfc0a1103bac328ac7b337ce3aaefd5d8d85b4
parent bcdcc341
...@@ -21,6 +21,7 @@ __all__ = [ ...@@ -21,6 +21,7 @@ __all__ = [
] ]
# TODO: remove this function
def create_runner( def create_runner(
class_full_name: Optional[str], *args, **kwargs class_full_name: Optional[str], *args, **kwargs
) -> Union[BaseRunner, Type[DefaultTask]]: ) -> Union[BaseRunner, Type[DefaultTask]]:
...@@ -30,11 +31,16 @@ def create_runner( ...@@ -30,11 +31,16 @@ def create_runner(
if class_full_name is None: if class_full_name is None:
runner_class = GeneralizedRCNNRunner runner_class = GeneralizedRCNNRunner
else: else:
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1) runner_class = import_runner(class_full_name)
runner_module = importlib.import_module(runner_module_name)
runner_class = getattr(runner_module, runner_class_name)
if issubclass(runner_class, DefaultTask): if issubclass(runner_class, DefaultTask):
# Return runner class for Lightning module since it requires config # Return runner class for Lightning module since it requires config
# to construct # to construct
return runner_class return runner_class
return runner_class(*args, **kwargs) return runner_class(*args, **kwargs)
def import_runner(class_full_name: str) -> Type[Union[BaseRunner, DefaultTask]]:
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
runner_module = importlib.import_module(runner_module_name)
runner_class = getattr(runner_module, runner_class_name)
return runner_class
...@@ -19,7 +19,13 @@ from d2go.config import ( ...@@ -19,7 +19,13 @@ from d2go.config import (
) )
from d2go.config.utils import get_diff_cfg from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import BaseRunner, create_runner, DefaultTask, RunnerV2Mixin from d2go.runner import (
BaseRunner,
create_runner,
DefaultTask,
import_runner,
RunnerV2Mixin,
)
from d2go.utils.helper import run_once from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info from detectron2.utils.collect_env import collect_env_info
...@@ -127,6 +133,37 @@ def build_basic_cli_args( ...@@ -127,6 +133,37 @@ def build_basic_cli_args(
return args return args
def create_cfg_from_cli(
config_file: str,
overwrites: Optional[List[str]],
runner_class: Union[None, str, Type[BaseRunner], Type[DefaultTask]],
) -> CfgNode:
"""
Centralized function to load config object from config file. It currently supports:
- YACS based config (return yacs's CfgNode)
"""
config_file = reroute_config_path(config_file)
with PathManager.open(config_file, "r") as f:
# TODO: switch to logger, note that we need to initilaize logger outside of main
# for running locally.
print("Loaded config file {}:\n{}".format(config_file, f.read()))
if isinstance(runner_class, str):
runner_class = import_runner(runner_class)
if runner_class is None or issubclass(runner_class, RunnerV2Mixin):
# Runner-less API
cfg = load_full_config_from_file(config_file)
else:
# backward compatible for old API
cfg = runner_class.get_default_cfg()
cfg.merge_from_file(config_file)
cfg.merge_from_list(overwrites or [])
cfg.freeze()
return cfg
def prepare_for_launch(args): def prepare_for_launch(args):
""" """
Load config, figure out working directory, create runner. Load config, figure out working directory, create runner.
...@@ -135,22 +172,19 @@ def prepare_for_launch(args): ...@@ -135,22 +172,19 @@ def prepare_for_launch(args):
priority than cfg.OUTPUT_DIR. priority than cfg.OUTPUT_DIR.
""" """
logger.info(args) logger.info(args)
runner = create_runner(args.runner)
with PathManager.open(reroute_config_path(args.config_file), "r") as f:
print("Loaded config file {}:\n{}".format(args.config_file, f.read()))
if isinstance(runner, RunnerV2Mixin):
cfg = load_full_config_from_file(args.config_file)
else:
cfg = runner.get_default_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts) cfg = create_cfg_from_cli(
cfg.freeze() config_file=args.config_file,
overwrites=args.opts,
runner_class=args.runner,
)
# overwrite the output_dir based on config if output is not set via cli
assert args.output_dir or args.config_file assert args.output_dir or args.config_file
output_dir = args.output_dir or cfg.OUTPUT_DIR output_dir = args.output_dir or cfg.OUTPUT_DIR
# TODO (T123980149): use runner_name across the board
runner = create_runner(args.runner)
return cfg, output_dir, runner return cfg, output_dir, runner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment