Commit 7bed2910 authored by Haroun Habeeb's avatar Haroun Habeeb Committed by Facebook GitHub Bot
Browse files

temp_new_allowed

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/438

Adding new fields to a config is only allowed if `new_allowed=True`. yacs `CfgNode` provides a `set_new_allowed(value: bool)` function.

We create a context manager like `temp_defrost` but for new_allowed to use it. We also implement unit test for the same

Reviewed By: yanglinfang, newstzpz, wat3rBro

Differential Revision: D41748992

fbshipit-source-id: 71d048511476001ca96e6b36dde4d177b11268d7
parent e2537c82
...@@ -12,6 +12,7 @@ from .config import ( ...@@ -12,6 +12,7 @@ from .config import (
load_full_config_from_file, load_full_config_from_file,
reroute_config_path, reroute_config_path,
temp_defrost, temp_defrost,
temp_new_allowed,
) )
...@@ -24,4 +25,5 @@ __all__ = [ ...@@ -24,4 +25,5 @@ __all__ = [
"load_full_config_from_file", "load_full_config_from_file",
"reroute_config_path", "reroute_config_path",
"temp_defrost", "temp_defrost",
"temp_new_allowed",
] ]
...@@ -96,6 +96,14 @@ def temp_defrost(cfg): ...@@ -96,6 +96,14 @@ def temp_defrost(cfg):
cfg.freeze() cfg.freeze()
@contextlib.contextmanager
def temp_new_allowed(cfg: CfgNode):
is_new_allowed = cfg.is_new_allowed()
cfg.set_new_allowed(True)
yield cfg
cfg.set_new_allowed(is_new_allowed)
@contextlib.contextmanager @contextlib.contextmanager
def reroute_load_yaml_with_base(): def reroute_load_yaml_with_base():
BASE_KEY = "_BASE_" BASE_KEY = "_BASE_"
......
...@@ -12,6 +12,7 @@ from d2go.config import ( ...@@ -12,6 +12,7 @@ from d2go.config import (
CfgNode, CfgNode,
load_full_config_from_file, load_full_config_from_file,
reroute_config_path, reroute_config_path,
temp_new_allowed,
) )
from d2go.config.utils import ( from d2go.config.utils import (
config_dict_to_list_str, config_dict_to_list_str,
...@@ -74,6 +75,18 @@ class TestConfig(unittest.TestCase): ...@@ -74,6 +75,18 @@ class TestConfig(unittest.TestCase):
self.assertEqual(cfg.MODEL.FBNET_V2.ARCH, "FBNetV3_A") # second base is loaded self.assertEqual(cfg.MODEL.FBNET_V2.ARCH, "FBNetV3_A") # second base is loaded
self.assertEqual(cfg.OUTPUT_DIR, "test") # non-base is loaded self.assertEqual(cfg.OUTPUT_DIR, "test") # non-base is loaded
def test_temp_new_allowed(self):
default_cfg = GeneralizedRCNNRunner.get_default_cfg()
def set_field(cfg):
cfg.THIS_BETTER_BE_A_NEW_CONFIG = 4
self.assertFalse("THIS_BETTER_BE_A_NEW_CONFIG" in default_cfg)
with temp_new_allowed(default_cfg):
set_field(default_cfg)
self.assertTrue("THIS_BETTER_BE_A_NEW_CONFIG" in default_cfg)
self.assertTrue(default_cfg.THIS_BETTER_BE_A_NEW_CONFIG == 4)
def test_default_cfg_dump_and_load(self): def test_default_cfg_dump_and_load(self):
default_cfg = GeneralizedRCNNRunner.get_default_cfg() default_cfg = GeneralizedRCNNRunner.get_default_cfg()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment