Commit 8d1c1c2c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support specifying default config from config itself

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/264

Reviewed By: tglik

Differential Revision: D36427439

fbshipit-source-id: 3502d61ccb3c3f67d7ccfc1f55777c74f6c3b970
parent 814badbc
...@@ -8,6 +8,7 @@ from .config import ( ...@@ -8,6 +8,7 @@ from .config import (
CfgNode, CfgNode,
CONFIG_CUSTOM_PARSE_REGISTRY, CONFIG_CUSTOM_PARSE_REGISTRY,
CONFIG_SCALING_METHOD_REGISTRY, CONFIG_SCALING_METHOD_REGISTRY,
load_full_config_from_file,
reroute_config_path, reroute_config_path,
temp_defrost, temp_defrost,
) )
...@@ -18,6 +19,7 @@ __all__ = [ ...@@ -18,6 +19,7 @@ __all__ = [
"CONFIG_SCALING_METHOD_REGISTRY", "CONFIG_SCALING_METHOD_REGISTRY",
"CfgNode", "CfgNode",
"auto_scale_world_size", "auto_scale_world_size",
"load_full_config_from_file",
"reroute_config_path", "reroute_config_path",
"temp_defrost", "temp_defrost",
] ]
...@@ -8,6 +8,7 @@ from typing import List ...@@ -8,6 +8,7 @@ from typing import List
import mock import mock
import yaml import yaml
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from detectron2.config import CfgNode as _CfgNode from detectron2.config import CfgNode as _CfgNode
from fvcore.common.registry import Registry from fvcore.common.registry import Registry
...@@ -16,6 +17,7 @@ from .utils import reroute_config_path ...@@ -16,6 +17,7 @@ from .utils import reroute_config_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CONFIG_CUSTOM_PARSE_REGISTRY = Registry("CONFIG_CUSTOM_PARSE") CONFIG_CUSTOM_PARSE_REGISTRY = Registry("CONFIG_CUSTOM_PARSE")
DEFAULTS_GENERATOR_KEY = "_DEFAULTS_"
def _opts_to_dict(opts: List[str]): def _opts_to_dict(opts: List[str]):
...@@ -80,6 +82,10 @@ class CfgNode(_CfgNode): ...@@ -80,6 +82,10 @@ class CfgNode(_CfgNode):
if frozen: if frozen:
self.freeze() self.freeze()
def get_default_cfg(self):
"""Return the defaults for this instance of CfgNode"""
return _resolve_default_config(self)
@contextlib.contextmanager @contextlib.contextmanager
def temp_defrost(cfg): def temp_defrost(cfg):
...@@ -167,3 +173,34 @@ def auto_scale_world_size(cfg, new_world_size): ...@@ -167,3 +173,34 @@ def auto_scale_world_size(cfg, new_world_size):
table = get_cfg_diff_table(cfg, original_cfg) table = get_cfg_diff_table(cfg, original_cfg)
logger.info("Auto-scaled the config according to the actual world size: \n" + table) logger.info("Auto-scaled the config according to the actual world size: \n" + table)
def _resolve_default_config(cfg: CfgNode) -> CfgNode:
if DEFAULTS_GENERATOR_KEY not in cfg:
raise ValueError(
f"Can't resolved default config because `{DEFAULTS_GENERATOR_KEY}` is"
f" missing from cfg: \n{cfg}"
)
updater_names: List[str] = cfg[DEFAULTS_GENERATOR_KEY]
assert isinstance(updater_names, list), updater_names
assert [isinstance(x, str) for x in updater_names], updater_names
# starting from a empty CfgNode, sequentially apply the generator
cfg = CfgNode()
for name in updater_names:
updater = CONFIG_UPDATER_REGISTRY.get(name)
cfg = updater(cfg)
# the resolved default config should keep the same default generator
cfg[DEFAULTS_GENERATOR_KEY] = updater_names
return cfg
def load_full_config_from_file(filename: str) -> CfgNode:
loaded_cfg = CfgNode.load_yaml_with_base(filename)
loaded_cfg = CfgNode(loaded_cfg) # cast Dict to CfgNode
cfg = loaded_cfg.get_default_cfg()
cfg.merge_from_other_cfg(loaded_cfg)
return cfg
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from mobile_cv.common.misc.registry import Registry
"""
This file contains all D2Go's builtin registries with global scope.
- These registries can be treated as "static". There'll be a bootstrap process happens
at the beginning of the program to make it works like the registrations happen
at compile time (like C++). In another word, the objects are guaranteed to be
registered to those builtin registries without user importing their code.
- Since the namespace is global, the registered name has to be unique across all projects.
"""
DEMO_REGISTRY = Registry("DEMO")
# Registry for config updater
CONFIG_UPDATER_REGISTRY = Registry("CONFIG_UPDATER")
...@@ -7,7 +7,12 @@ import logging ...@@ -7,7 +7,12 @@ import logging
import os import os
import unittest import unittest
from d2go.config import auto_scale_world_size, CfgNode, reroute_config_path from d2go.config import (
auto_scale_world_size,
CfgNode,
load_full_config_from_file,
reroute_config_path,
)
from d2go.config.utils import ( from d2go.config.utils import (
config_dict_to_list_str, config_dict_to_list_str,
flatten_config_dict, flatten_config_dict,
...@@ -15,6 +20,7 @@ from d2go.config.utils import ( ...@@ -15,6 +20,7 @@ from d2go.config.utils import (
get_diff_cfg, get_diff_cfg,
get_from_flattened_config_dict, get_from_flattened_config_dict,
) )
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.helper import get_resource_path from d2go.utils.testing.helper import get_resource_path
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
...@@ -226,5 +232,31 @@ class TestAutoScaleWorldSize(unittest.TestCase): ...@@ -226,5 +232,31 @@ class TestAutoScaleWorldSize(unittest.TestCase):
self.assertEqual(cfg.SOLVER.IMS_PER_BATCH, batch_size_x8) self.assertEqual(cfg.SOLVER.IMS_PER_BATCH, batch_size_x8)
class TestConfigDefaultsGen(unittest.TestCase):
def test_case1(self):
# register in local scope
@CONFIG_UPDATER_REGISTRY.register()
def _test1(cfg):
cfg.TEST1 = CfgNode()
cfg.TEST1.X = 1
return cfg
@CONFIG_UPDATER_REGISTRY.register()
def _test2(cfg):
cfg.TEST2 = CfgNode()
cfg.TEST2.Y = 2
return cfg
filename = get_resource_path("defaults_gen_case1.yaml")
cfg = load_full_config_from_file(filename)
default_cfg = cfg.get_default_cfg()
# default value is 1
self.assertEqual(default_cfg.TEST1.X, 1)
self.assertEqual(default_cfg.TEST2.Y, 2)
# yaml file overwrites it to 3
self.assertEqual(cfg.TEST1.X, 3)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
_DEFAULTS_: ["_test1", "_test2"]
TEST1:
X: 3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment