Commit 39054767 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support multi-base for config re-route

Summary: as title

Reviewed By: Cysu

Differential Revision: D31901433

fbshipit-source-id: 1749527c04c392c830e1a49bca8313ddf903d7b1
parent 421960b3
...@@ -56,16 +56,22 @@ def reroute_load_yaml_with_base(): ...@@ -56,16 +56,22 @@ def reroute_load_yaml_with_base():
_safe_load = yaml.safe_load _safe_load = yaml.safe_load
_unsafe_load = yaml.unsafe_load _unsafe_load = yaml.unsafe_load
def _reroute_base(cfg):
if BASE_KEY in cfg:
if isinstance(cfg[BASE_KEY], list):
cfg[BASE_KEY] = [reroute_config_path(x) for x in cfg[BASE_KEY]]
else:
cfg[BASE_KEY] = reroute_config_path(cfg[BASE_KEY])
return cfg
def mock_safe_load(f): def mock_safe_load(f):
cfg = _safe_load(f) cfg = _safe_load(f)
if BASE_KEY in cfg: cfg = _reroute_base(cfg)
cfg[BASE_KEY] = reroute_config_path(cfg[BASE_KEY])
return cfg return cfg
def mock_unsafe_load(f): def mock_unsafe_load(f):
cfg = _unsafe_load(f) cfg = _unsafe_load(f)
if BASE_KEY in cfg: cfg = _reroute_base(cfg)
cfg[BASE_KEY] = reroute_config_path(cfg[BASE_KEY])
return cfg return cfg
with mock.patch("yaml.safe_load", side_effect=mock_safe_load): with mock.patch("yaml.safe_load", side_effect=mock_safe_load):
......
...@@ -19,6 +19,7 @@ def reroute_config_path(path: str) -> str: ...@@ -19,6 +19,7 @@ def reroute_config_path(path: str) -> str:
Those config are considered as code, so they'll reflect your current checkout, Those config are considered as code, so they'll reflect your current checkout,
try using canary if you have local changes. try using canary if you have local changes.
""" """
assert isinstance(path, str), path
if path.startswith("d2go://"): if path.startswith("d2go://"):
rel_path = path[len("d2go://") :] rel_path = path[len("d2go://") :]
......
...@@ -51,6 +51,22 @@ class TestConfig(unittest.TestCase): ...@@ -51,6 +51,22 @@ class TestConfig(unittest.TestCase):
another_cfg = default_cfg.clone() another_cfg = default_cfg.clone()
another_cfg.merge_from_file(file_name) another_cfg.merge_from_file(file_name)
def test_base_reroute(self):
default_cfg = GeneralizedRCNNRunner().get_default_cfg()
# use rerouted file as base
cfg = default_cfg.clone()
cfg.merge_from_file(get_resource_path("rerouted_base.yaml"))
self.assertEqual(cfg.MODEL.MASK_ON, True) # base is loaded
self.assertEqual(cfg.MODEL.FBNET_V2.ARCH, "test") # non-base is loaded
# use multiple files as base
cfg = default_cfg.clone()
cfg.merge_from_file(get_resource_path("rerouted_multi_base.yaml"))
self.assertEqual(cfg.MODEL.MASK_ON, True) # base is loaded
self.assertEqual(cfg.MODEL.FBNET_V2.ARCH, "FBNetV3_A") # second base is loaded
self.assertEqual(cfg.OUTPUT_DIR, "test") # non-base is loaded
def test_default_cfg_dump_and_load(self): def test_default_cfg_dump_and_load(self):
default_cfg = GeneralizedRCNNRunner().get_default_cfg() default_cfg = GeneralizedRCNNRunner().get_default_cfg()
......
_BASE_: "detectron2go://mask_rcnn_fbnetv3a_C4.yaml"
MODEL:
FBNET_V2:
ARCH: "test"
_BASE_:
- "rerouted_base.yaml"
- "detectron2go://mask_rcnn_fbnetv3a_C4.yaml"
OUTPUT_DIR: "test"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment