Commit 521b3cad authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support for using config file with _DEFAULTS_ via cli

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/271

- set `get_default_cfg = None` to opt-in this new feature.
- support config with `_DEFAULTS_` when loading config file. Note that we don't check explicitly `"_DEFAULTS_" in config file` but check `runner.get_default_cfg == None`, this is because runner having `get_default_cfg` and config having `_DEFAULTS_` should be mutually exclusive, and `load_full_config_from_file` can raise proper error if `_DEFAULTS_` is missing.
- we also need to save `_DEFAULTS_` in the diff config.

Reviewed By: tglik

Differential Revision: D36868581

fbshipit-source-id: e0e19309c3df5a85383ce1454b321a68d0868dc4
parent 318a3d79
...@@ -8,16 +8,14 @@ from typing import List ...@@ -8,16 +8,14 @@ from typing import List
import mock import mock
import yaml import yaml
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from detectron2.config import CfgNode as _CfgNode from detectron2.config import CfgNode as _CfgNode
from fvcore.common.registry import Registry from fvcore.common.registry import Registry
from .utils import reroute_config_path from .utils import reroute_config_path, resolve_default_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CONFIG_CUSTOM_PARSE_REGISTRY = Registry("CONFIG_CUSTOM_PARSE") CONFIG_CUSTOM_PARSE_REGISTRY = Registry("CONFIG_CUSTOM_PARSE")
DEFAULTS_GENERATOR_KEY = "_DEFAULTS_"
def _opts_to_dict(opts: List[str]): def _opts_to_dict(opts: List[str]):
...@@ -65,6 +63,7 @@ class CfgNode(_CfgNode): ...@@ -65,6 +63,7 @@ class CfgNode(_CfgNode):
@staticmethod @staticmethod
def load_yaml_with_base(filename: str, *args, **kwargs): def load_yaml_with_base(filename: str, *args, **kwargs):
filename = reroute_config_path(filename)
with reroute_load_yaml_with_base(): with reroute_load_yaml_with_base():
return _CfgNode.load_yaml_with_base(filename, *args, **kwargs) return _CfgNode.load_yaml_with_base(filename, *args, **kwargs)
...@@ -84,7 +83,7 @@ class CfgNode(_CfgNode): ...@@ -84,7 +83,7 @@ class CfgNode(_CfgNode):
def get_default_cfg(self): def get_default_cfg(self):
"""Return the defaults for this instance of CfgNode""" """Return the defaults for this instance of CfgNode"""
return _resolve_default_config(self) return resolve_default_config(self)
@contextlib.contextmanager @contextlib.contextmanager
...@@ -178,29 +177,6 @@ def auto_scale_world_size(cfg, new_world_size): ...@@ -178,29 +177,6 @@ def auto_scale_world_size(cfg, new_world_size):
logger.info("Auto-scaled the config according to the actual world size: \n" + table) logger.info("Auto-scaled the config according to the actual world size: \n" + table)
def _resolve_default_config(cfg: CfgNode) -> CfgNode:
if DEFAULTS_GENERATOR_KEY not in cfg:
raise ValueError(
f"Can't resolved default config because `{DEFAULTS_GENERATOR_KEY}` is"
f" missing from cfg: \n{cfg}"
)
updater_names: List[str] = cfg[DEFAULTS_GENERATOR_KEY]
assert isinstance(updater_names, list), updater_names
assert [isinstance(x, str) for x in updater_names], updater_names
# starting from a empty CfgNode, sequentially apply the generator
cfg = CfgNode()
for name in updater_names:
updater = CONFIG_UPDATER_REGISTRY.get(name)
cfg = updater(cfg)
# the resolved default config should keep the same default generator
cfg[DEFAULTS_GENERATOR_KEY] = updater_names
return cfg
def load_full_config_from_file(filename: str) -> CfgNode: def load_full_config_from_file(filename: str) -> CfgNode:
loaded_cfg = CfgNode.load_yaml_with_base(filename) loaded_cfg = CfgNode.load_yaml_with_base(filename)
loaded_cfg = CfgNode(loaded_cfg) # cast Dict to CfgNode loaded_cfg = CfgNode(loaded_cfg) # cast Dict to CfgNode
......
...@@ -7,9 +7,13 @@ from enum import Enum ...@@ -7,9 +7,13 @@ from enum import Enum
from typing import Any, Dict, List from typing import Any, Dict, List
import pkg_resources import pkg_resources
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from mobile_cv.common.misc.oss_utils import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
logger = logging.getLogger(__name__)
DEFAULTS_KEY = "_DEFAULTS_"
def reroute_config_path(path: str) -> str: def reroute_config_path(path: str) -> str:
""" """
...@@ -210,7 +214,18 @@ def get_diff_cfg(old_cfg, new_cfg): ...@@ -210,7 +214,18 @@ def get_diff_cfg(old_cfg, new_cfg):
return out return out
out = new_cfg.__class__() out = new_cfg.__class__()
return get_diff_cfg_rec(old_cfg, new_cfg, out) diff_cfg = get_diff_cfg_rec(old_cfg, new_cfg, out)
# Keep the `_DEFAULTS_` even though they should be the same
old_defaults = old_cfg.get(DEFAULTS_KEY, None)
new_defaults = new_cfg.get(DEFAULTS_KEY, None)
assert (
old_defaults == new_defaults
), f"{DEFAULTS_KEY} doesn't match! old ({old_defaults}) vs new ({new_defaults})"
if new_defaults is not None:
diff_cfg[DEFAULTS_KEY] = new_defaults
return diff_cfg
def namedtuple_to_dict(obj: Any): def namedtuple_to_dict(obj: Any):
...@@ -223,3 +238,27 @@ def namedtuple_to_dict(obj: Any): ...@@ -223,3 +238,27 @@ def namedtuple_to_dict(obj: Any):
else: else:
res[k] = v res[k] = v
return res return res
def resolve_default_config(cfg):
if DEFAULTS_KEY not in cfg:
raise ValueError(
f"Can't resolved default config because `{DEFAULTS_KEY}` is"
f" missing from cfg: \n{cfg}"
)
updater_names: List[str] = cfg[DEFAULTS_KEY]
assert isinstance(updater_names, list), updater_names
assert [isinstance(x, str) for x in updater_names], updater_names
logger.info(f"Resolving default config by applying updaters: {updater_names} ...")
# starting from a empty CfgNode, sequentially apply the generator
cfg = type(cfg)()
for name in updater_names:
updater = CONFIG_UPDATER_REGISTRY.get(name)
cfg = updater(cfg)
# the resolved default config should keep the same default generator
cfg[DEFAULTS_KEY] = updater_names
return cfg
...@@ -5,12 +5,14 @@ ...@@ -5,12 +5,14 @@
import importlib import importlib
from typing import Optional, Type, Union from typing import Optional, Type, Union
from .api import RunnerV2Mixin
from .default_runner import BaseRunner, Detectron2GoRunner, GeneralizedRCNNRunner from .default_runner import BaseRunner, Detectron2GoRunner, GeneralizedRCNNRunner
from .lightning_task import DefaultTask from .lightning_task import DefaultTask
from .training_hooks import TRAINER_HOOKS_REGISTRY from .training_hooks import TRAINER_HOOKS_REGISTRY
__all__ = [ __all__ = [
"RunnerV2Mixin",
"BaseRunner", "BaseRunner",
"Detectron2GoRunner", "Detectron2GoRunner",
"GeneralizedRCNNRunner", "GeneralizedRCNNRunner",
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import final
from d2go.config import CfgNode
class RunnerV2Mixin(object):
"""
Interface for (V2) Runner:
- `get_default_cfg` is not a runner method anymore.
"""
@classmethod
@final
def get_default_cfg(cls) -> CfgNode:
raise NotImplementedError("")
...@@ -13,12 +13,13 @@ import torch ...@@ -13,12 +13,13 @@ import torch
from d2go.config import ( from d2go.config import (
auto_scale_world_size, auto_scale_world_size,
CfgNode, CfgNode,
load_full_config_from_file,
reroute_config_path, reroute_config_path,
temp_defrost, temp_defrost,
) )
from d2go.config.utils import get_diff_cfg from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import BaseRunner, create_runner, DefaultTask from d2go.runner import BaseRunner, create_runner, DefaultTask, RunnerV2Mixin
from d2go.utils.helper import run_once from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info from detectron2.utils.collect_env import collect_env_info
...@@ -127,11 +128,15 @@ def prepare_for_launch(args): ...@@ -127,11 +128,15 @@ def prepare_for_launch(args):
logger.info(args) logger.info(args)
runner = create_runner(args.runner) runner = create_runner(args.runner)
cfg = runner.get_default_cfg()
with PathManager.open(reroute_config_path(args.config_file), "r") as f: with PathManager.open(reroute_config_path(args.config_file), "r") as f:
print("Loaded config file {}:\n{}".format(args.config_file, f.read())) print("Loaded config file {}:\n{}".format(args.config_file, f.read()))
cfg.merge_from_file(args.config_file)
if isinstance(runner, RunnerV2Mixin):
cfg = load_full_config_from_file(args.config_file)
else:
cfg = runner.get_default_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts) cfg.merge_from_list(args.opts)
cfg.freeze() cfg.freeze()
...@@ -180,15 +185,15 @@ def setup_after_launch( ...@@ -180,15 +185,15 @@ def setup_after_launch(
logger.info("Running with runner: {}".format(runner)) logger.info("Running with runner: {}".format(runner))
# save the diff config # save the diff config
if runner is not None: default_cfg = (
default_cfg = runner.get_default_cfg() runner.get_default_cfg()
dump_cfg( if runner and not isinstance(runner, RunnerV2Mixin)
get_diff_cfg(default_cfg, cfg), else cfg.get_default_cfg()
os.path.join(output_dir, "diff_config.yaml"), )
) dump_cfg(
else: get_diff_cfg(default_cfg, cfg),
# TODO: support getting default_cfg without runner. os.path.join(output_dir, "diff_config.yaml"),
pass )
# scale the config after dumping so that dumped config files keep original world size # scale the config after dumping so that dumped config files keep original world size
auto_scale_world_size(cfg, new_world_size=comm.get_world_size()) auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment