Commit 8d1c1c2c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support specifying default config from config itself

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/264

Reviewed By: tglik

Differential Revision: D36427439

fbshipit-source-id: 3502d61ccb3c3f67d7ccfc1f55777c74f6c3b970
parent 814badbc
......@@ -8,6 +8,7 @@ from .config import (
CfgNode,
CONFIG_CUSTOM_PARSE_REGISTRY,
CONFIG_SCALING_METHOD_REGISTRY,
load_full_config_from_file,
reroute_config_path,
temp_defrost,
)
......@@ -18,6 +19,7 @@ __all__ = [
"CONFIG_SCALING_METHOD_REGISTRY",
"CfgNode",
"auto_scale_world_size",
"load_full_config_from_file",
"reroute_config_path",
"temp_defrost",
]
......@@ -8,6 +8,7 @@ from typing import List
import mock
import yaml
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from detectron2.config import CfgNode as _CfgNode
from fvcore.common.registry import Registry
......@@ -16,6 +17,7 @@ from .utils import reroute_config_path
logger = logging.getLogger(__name__)
CONFIG_CUSTOM_PARSE_REGISTRY = Registry("CONFIG_CUSTOM_PARSE")
DEFAULTS_GENERATOR_KEY = "_DEFAULTS_"
def _opts_to_dict(opts: List[str]):
......@@ -80,6 +82,10 @@ class CfgNode(_CfgNode):
if frozen:
self.freeze()
def get_default_cfg(self):
"""Return the defaults for this instance of CfgNode"""
return _resolve_default_config(self)
@contextlib.contextmanager
def temp_defrost(cfg):
......@@ -167,3 +173,34 @@ def auto_scale_world_size(cfg, new_world_size):
table = get_cfg_diff_table(cfg, original_cfg)
logger.info("Auto-scaled the config according to the actual world size: \n" + table)
def _resolve_default_config(cfg: CfgNode) -> CfgNode:
if DEFAULTS_GENERATOR_KEY not in cfg:
raise ValueError(
f"Can't resolved default config because `{DEFAULTS_GENERATOR_KEY}` is"
f" missing from cfg: \n{cfg}"
)
updater_names: List[str] = cfg[DEFAULTS_GENERATOR_KEY]
assert isinstance(updater_names, list), updater_names
assert [isinstance(x, str) for x in updater_names], updater_names
# starting from a empty CfgNode, sequentially apply the generator
cfg = CfgNode()
for name in updater_names:
updater = CONFIG_UPDATER_REGISTRY.get(name)
cfg = updater(cfg)
# the resolved default config should keep the same default generator
cfg[DEFAULTS_GENERATOR_KEY] = updater_names
return cfg
def load_full_config_from_file(filename: str) -> CfgNode:
loaded_cfg = CfgNode.load_yaml_with_base(filename)
loaded_cfg = CfgNode(loaded_cfg) # cast Dict to CfgNode
cfg = loaded_cfg.get_default_cfg()
cfg.merge_from_other_cfg(loaded_cfg)
return cfg
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from mobile_cv.common.misc.registry import Registry
"""
This file contains all D2Go's builtin registries with global scope.
- These registries can be treated as "static". There'll be a bootstrap process happens
at the beginning of the program to make it works like the registrations happen
at compile time (like C++). In another word, the objects are guaranteed to be
registered to those builtin registries without user importing their code.
- Since the namespace is global, the registered name has to be unique across all projects.
"""
DEMO_REGISTRY = Registry("DEMO")
# Registry for config updater
CONFIG_UPDATER_REGISTRY = Registry("CONFIG_UPDATER")
......@@ -7,7 +7,12 @@ import logging
import os
import unittest
from d2go.config import auto_scale_world_size, CfgNode, reroute_config_path
from d2go.config import (
auto_scale_world_size,
CfgNode,
load_full_config_from_file,
reroute_config_path,
)
from d2go.config.utils import (
config_dict_to_list_str,
flatten_config_dict,
......@@ -15,6 +20,7 @@ from d2go.config.utils import (
get_diff_cfg,
get_from_flattened_config_dict,
)
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.helper import get_resource_path
from mobile_cv.common.misc.file_utils import make_temp_directory
......@@ -226,5 +232,31 @@ class TestAutoScaleWorldSize(unittest.TestCase):
self.assertEqual(cfg.SOLVER.IMS_PER_BATCH, batch_size_x8)
class TestConfigDefaultsGen(unittest.TestCase):
def test_case1(self):
# register in local scope
@CONFIG_UPDATER_REGISTRY.register()
def _test1(cfg):
cfg.TEST1 = CfgNode()
cfg.TEST1.X = 1
return cfg
@CONFIG_UPDATER_REGISTRY.register()
def _test2(cfg):
cfg.TEST2 = CfgNode()
cfg.TEST2.Y = 2
return cfg
filename = get_resource_path("defaults_gen_case1.yaml")
cfg = load_full_config_from_file(filename)
default_cfg = cfg.get_default_cfg()
# default value is 1
self.assertEqual(default_cfg.TEST1.X, 1)
self.assertEqual(default_cfg.TEST2.Y, 2)
# yaml file overwrites it to 3
self.assertEqual(cfg.TEST1.X, 3)
if __name__ == "__main__":
unittest.main()
_DEFAULTS_: ["_test1", "_test2"]
TEST1:
X: 3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment