Commit 521b3cad authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support for using config file with _DEFAULTS_ via cli

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/271

- set `get_default_cfg = None` to opt-in this new feature.
- support config with `_DEFAULTS_` when loading config file. Note that we don't check explicitly `"_DEFAULTS_" in config file` but check `runner.get_default_cfg == None`, this is because runner having `get_default_cfg` and config having `_DEFAULTS_` should be mutually exclusive, and `load_full_config_from_file` can raise proper error if `_DEFAULTS_` is missing.
- we also need to save `_DEFAULTS_` in the diff config.

Reviewed By: tglik

Differential Revision: D36868581

fbshipit-source-id: e0e19309c3df5a85383ce1454b321a68d0868dc4
parent 318a3d79
......@@ -8,16 +8,14 @@ from typing import List
import mock
import yaml
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from detectron2.config import CfgNode as _CfgNode
from fvcore.common.registry import Registry
from .utils import reroute_config_path
from .utils import reroute_config_path, resolve_default_config
logger = logging.getLogger(__name__)
CONFIG_CUSTOM_PARSE_REGISTRY = Registry("CONFIG_CUSTOM_PARSE")
DEFAULTS_GENERATOR_KEY = "_DEFAULTS_"
def _opts_to_dict(opts: List[str]):
......@@ -65,6 +63,7 @@ class CfgNode(_CfgNode):
@staticmethod
def load_yaml_with_base(filename: str, *args, **kwargs):
filename = reroute_config_path(filename)
with reroute_load_yaml_with_base():
return _CfgNode.load_yaml_with_base(filename, *args, **kwargs)
......@@ -84,7 +83,7 @@ class CfgNode(_CfgNode):
def get_default_cfg(self):
"""Return the defaults for this instance of CfgNode"""
return _resolve_default_config(self)
return resolve_default_config(self)
@contextlib.contextmanager
......@@ -178,29 +177,6 @@ def auto_scale_world_size(cfg, new_world_size):
logger.info("Auto-scaled the config according to the actual world size: \n" + table)
def _resolve_default_config(cfg: CfgNode) -> CfgNode:
if DEFAULTS_GENERATOR_KEY not in cfg:
raise ValueError(
f"Can't resolved default config because `{DEFAULTS_GENERATOR_KEY}` is"
f" missing from cfg: \n{cfg}"
)
updater_names: List[str] = cfg[DEFAULTS_GENERATOR_KEY]
assert isinstance(updater_names, list), updater_names
assert [isinstance(x, str) for x in updater_names], updater_names
# starting from a empty CfgNode, sequentially apply the generator
cfg = CfgNode()
for name in updater_names:
updater = CONFIG_UPDATER_REGISTRY.get(name)
cfg = updater(cfg)
# the resolved default config should keep the same default generator
cfg[DEFAULTS_GENERATOR_KEY] = updater_names
return cfg
def load_full_config_from_file(filename: str) -> CfgNode:
loaded_cfg = CfgNode.load_yaml_with_base(filename)
loaded_cfg = CfgNode(loaded_cfg) # cast Dict to CfgNode
......
......@@ -7,9 +7,13 @@ from enum import Enum
from typing import Any, Dict, List
import pkg_resources
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from mobile_cv.common.misc.oss_utils import fb_overwritable
logger = logging.getLogger(__name__)
DEFAULTS_KEY = "_DEFAULTS_"
def reroute_config_path(path: str) -> str:
"""
......@@ -210,7 +214,18 @@ def get_diff_cfg(old_cfg, new_cfg):
return out
out = new_cfg.__class__()
return get_diff_cfg_rec(old_cfg, new_cfg, out)
diff_cfg = get_diff_cfg_rec(old_cfg, new_cfg, out)
# Keep the `_DEFAULTS_` even though they should be the same
old_defaults = old_cfg.get(DEFAULTS_KEY, None)
new_defaults = new_cfg.get(DEFAULTS_KEY, None)
assert (
old_defaults == new_defaults
), f"{DEFAULTS_KEY} doesn't match! old ({old_defaults}) vs new ({new_defaults})"
if new_defaults is not None:
diff_cfg[DEFAULTS_KEY] = new_defaults
return diff_cfg
def namedtuple_to_dict(obj: Any):
......@@ -223,3 +238,27 @@ def namedtuple_to_dict(obj: Any):
else:
res[k] = v
return res
def resolve_default_config(cfg):
if DEFAULTS_KEY not in cfg:
raise ValueError(
f"Can't resolved default config because `{DEFAULTS_KEY}` is"
f" missing from cfg: \n{cfg}"
)
updater_names: List[str] = cfg[DEFAULTS_KEY]
assert isinstance(updater_names, list), updater_names
assert [isinstance(x, str) for x in updater_names], updater_names
logger.info(f"Resolving default config by applying updaters: {updater_names} ...")
# starting from a empty CfgNode, sequentially apply the generator
cfg = type(cfg)()
for name in updater_names:
updater = CONFIG_UPDATER_REGISTRY.get(name)
cfg = updater(cfg)
# the resolved default config should keep the same default generator
cfg[DEFAULTS_KEY] = updater_names
return cfg
......@@ -5,12 +5,14 @@
import importlib
from typing import Optional, Type, Union
from .api import RunnerV2Mixin
from .default_runner import BaseRunner, Detectron2GoRunner, GeneralizedRCNNRunner
from .lightning_task import DefaultTask
from .training_hooks import TRAINER_HOOKS_REGISTRY
__all__ = [
"RunnerV2Mixin",
"BaseRunner",
"Detectron2GoRunner",
"GeneralizedRCNNRunner",
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import final
from d2go.config import CfgNode
class RunnerV2Mixin(object):
"""
Interface for (V2) Runner:
- `get_default_cfg` is not a runner method anymore.
"""
@classmethod
@final
def get_default_cfg(cls) -> CfgNode:
raise NotImplementedError("")
......@@ -13,12 +13,13 @@ import torch
from d2go.config import (
auto_scale_world_size,
CfgNode,
load_full_config_from_file,
reroute_config_path,
temp_defrost,
)
from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import BaseRunner, create_runner, DefaultTask
from d2go.runner import BaseRunner, create_runner, DefaultTask, RunnerV2Mixin
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info
......@@ -127,11 +128,15 @@ def prepare_for_launch(args):
logger.info(args)
runner = create_runner(args.runner)
cfg = runner.get_default_cfg()
with PathManager.open(reroute_config_path(args.config_file), "r") as f:
print("Loaded config file {}:\n{}".format(args.config_file, f.read()))
if isinstance(runner, RunnerV2Mixin):
cfg = load_full_config_from_file(args.config_file)
else:
cfg = runner.get_default_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
......@@ -180,15 +185,15 @@ def setup_after_launch(
logger.info("Running with runner: {}".format(runner))
# save the diff config
if runner is not None:
default_cfg = runner.get_default_cfg()
default_cfg = (
runner.get_default_cfg()
if runner and not isinstance(runner, RunnerV2Mixin)
else cfg.get_default_cfg()
)
dump_cfg(
get_diff_cfg(default_cfg, cfg),
os.path.join(output_dir, "diff_config.yaml"),
)
else:
# TODO: support getting default_cfg without runner.
pass
# scale the config after dumping so that dumped config files keep original world size
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment