Commit 81ab967f authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

misc update to config utils

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/77

- Reimplement `get_cfg_diff_table` by reusing other utils
- Add a `reorder` option to `flatten_config_dict`
- Remove the legacy BC support for `ARCH_DEF`, including `str_wrap_fbnet_arch_def` and customized `merge_from_other_cfg`.
- Move `temp_defrost` from `utils.py` to `config.py`, so that nothing from `utils.py` needs to be forwarded into the package namespace anymore
- Merge `test_config_utils.py` and `test_configs.py`

Reviewed By: zhanghang1989

Differential Revision: D28734493

fbshipit-source-id: 925f5944cf0e9019e4c54462e851ea16a5c94b8c
parent ad9f35c7
...@@ -2,11 +2,20 @@ ...@@ -2,11 +2,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .config import ( # noqa, forward namespace # forward the namespace to avoid `d2go.config.config`
from .config import (
CONFIG_SCALING_METHOD_REGISTRY, CONFIG_SCALING_METHOD_REGISTRY,
CfgNode, CfgNode,
auto_scale_world_size, auto_scale_world_size,
reroute_config_path, reroute_config_path,
get_cfg_diff_table, temp_defrost,
) )
from .utils import temp_defrost # noqa, forward namespace
__all__ = [
"CONFIG_SCALING_METHOD_REGISTRY",
"CfgNode",
"auto_scale_world_size",
"reroute_config_path",
"temp_defrost",
]
...@@ -14,48 +14,6 @@ from fvcore.common.registry import Registry ...@@ -14,48 +14,6 @@ from fvcore.common.registry import Registry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_cfg_diff_table(cfg, original_cfg):
    """Return a pipe-format table of the leaf values that differ between two configs.

    Both arguments must be ``CfgNode`` trees exposing exactly the same set of
    dotted leaf keys; this is enforced with assertions.

    Args:
        cfg: the new config (shown in the "new value" column).
        original_cfg: the reference config (shown in the "old value" column).

    Returns:
        The string produced by ``tabulate`` with one row per differing key.
    """
    from d2go.config import CfgNode

    def _leaf_paths(node):
        # Collect dotted paths of all non-CfgNode values, visiting keys in sorted order.
        assert isinstance(node, CfgNode)
        paths = []
        for name in sorted(node.keys()):
            child = node[name]
            if not isinstance(child, CfgNode):
                paths.append(name)
                continue
            paths.extend(
                "{}.{}".format(name, tail) for tail in _leaf_paths(child)
            )
        return paths

    def _lookup(node, path):
        # Walk down the tree one dotted component at a time.
        for part in path.split("."):
            node = node[part]
        return node

    old_paths = _leaf_paths(original_cfg)
    new_paths = _leaf_paths(cfg)
    assert old_paths == new_paths

    rows = []
    for path in new_paths:
        before = _lookup(original_cfg, path)
        after = _lookup(cfg, path)
        if before != after:
            rows.append([path, before, after])

    from tabulate import tabulate

    return tabulate(
        rows,
        tablefmt="pipe",
        headers=["config key", "old value", "new value"],
    )
class CfgNode(_CfgNode): class CfgNode(_CfgNode):
@classmethod @classmethod
def cast_from_other_class(cls, other_cfg): def cast_from_other_class(cls, other_cfg):
...@@ -76,25 +34,21 @@ class CfgNode(_CfgNode): ...@@ -76,25 +34,21 @@ class CfgNode(_CfgNode):
with reroute_load_yaml_with_base(): with reroute_load_yaml_with_base():
return _CfgNode.load_yaml_with_base(filename, *args, **kwargs) return _CfgNode.load_yaml_with_base(filename, *args, **kwargs)
def merge_from_other_cfg(self, cfg_other):
    """Merge ``cfg_other`` into this config, patching a legacy ``ARCH_DEF`` value first.

    If ``cfg_other`` carries ``MODEL.FBNET_V2.ARCH_DEF == ""``, that value is
    rewritten to ``[]`` on ``cfg_other`` itself (in place) and a warning is
    logged, before delegating to the parent implementation.
    """
    # NOTE: D24397488 changes default MODEL.FBNET_V2.ARCH_DEF from "" to [], change
    # the value to be able to load old full configs.
    # TODO: remove this by end of 2020.
    fbnet_section = cfg_other.get("MODEL", {}).get("FBNET_V2", {})
    legacy_arch_def = fbnet_section.get("ARCH_DEF", None)
    if legacy_arch_def == "":
        import logging

        logging.getLogger(__name__).warning(
            "Default value for MODEL.FBNET_V2.ARCH_DEF has changed to []"
        )
        cfg_other.MODEL.FBNET_V2.ARCH_DEF = []

    return super().merge_from_other_cfg(cfg_other)
def __hash__(self): def __hash__(self):
# dump follows alphabetical order, thus good for hash use # dump follows alphabetical order, thus good for hash use
return hash(self.dump()) return hash(self.dump())
@contextlib.contextmanager
def temp_defrost(cfg):
    """Temporarily make ``cfg`` mutable for the duration of a ``with`` block.

    If ``cfg`` is frozen on entry it is defrosted, and frozen again on exit.
    If it is not frozen on entry its state is left untouched.

    Args:
        cfg: a config object exposing ``is_frozen()``, ``defrost()`` and
            ``freeze()``.

    Yields:
        The same ``cfg`` object, defrosted.
    """
    was_frozen = cfg.is_frozen()
    if was_frozen:
        cfg.defrost()
    try:
        yield cfg
    finally:
        # Re-freeze even when the body raises; otherwise an exception would
        # leave a previously frozen config silently mutable.
        if was_frozen:
            cfg.freeze()
@contextlib.contextmanager @contextlib.contextmanager
def reroute_load_yaml_with_base(): def reroute_load_yaml_with_base():
BASE_KEY = "_BASE_" BASE_KEY = "_BASE_"
...@@ -161,5 +115,7 @@ def auto_scale_world_size(cfg, new_world_size): ...@@ -161,5 +115,7 @@ def auto_scale_world_size(cfg, new_world_size):
if frozen: if frozen:
cfg.freeze() cfg.freeze()
from d2go.config.utils import get_cfg_diff_table
table = get_cfg_diff_table(cfg, original_cfg) table = get_cfg_diff_table(cfg, original_cfg)
logger.info("Auto-scaled the config according to the actual world size: \n" + table) logger.info("Auto-scaled the config according to the actual world size: \n" + table)
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import copy
import json
import shlex
from typing import Dict, List from typing import Dict, List
@contextlib.contextmanager def _flatten_config_dict(x, reorder, prefix):
def temp_defrost(cfg): if not isinstance(x, dict):
is_frozen = cfg.is_frozen() return {prefix: x}
if is_frozen:
cfg.defrost()
yield cfg
if is_frozen:
cfg.freeze()
def str_wrap_fbnet_arch_def(d: Dict, inplace=False) -> Dict:
"""Replaces MODEL.FBNET_V2.ARCH_DEF with wrapped json string
Searches the input dict to see if it contains MODEL.FBNET_V2.ARCH_DEF
and replaces the value with a wrapped json string
The json string is created because FBNet builder runs json.loads on the
archdef. The json string needs to be wrapped in another string because
CfgNode runs literal_eval in order to check whether it should continue
to create CfgNodes if the value is a dict.
arch_def = {...} # {...}
arch_def = json.dumps(arch_def) # '{...}'
arch_def = strwrap(arch_def) # '"{...}"'
CfgNode(arch_def) => literal_eval(arch_def) # '{...}'
FBNetBuilder(arch_def) => json.loads(arch_def) # {...}
Example:
config = {"MODEL": {"FBNET_V2": {"ARCH_DEF": [1, 1, 1]}}}
str_wrap_fbnet_arch_def(config)
=> {"MODEL": {"FBNET_V2": {"ARCH_DEF": '''"[1, 1, 1]"'''}}}
"""
if not inplace:
d = copy.deepcopy(d)
try:
archdef = d["MODEL"]["FBNET_V2"]["ARCH_DEF"]
# MODEL.FBNET_V2.ARCH_DEF needs to be json str
archdef = json.dumps(archdef)
# CfgNode runs literal_eval when merging so wrap around str
archdef = shlex.quote(archdef)
d["MODEL"]["FBNET_V2"]["ARCH_DEF"] = archdef
except KeyError:
pass
d = {}
for k in sorted(x.keys()) if reorder else x.keys():
v = x[k]
new_key = f"{prefix}.{k}" if prefix else k
d.update(_flatten_config_dict(v, reorder, new_key))
return d return d
def flatten_config_dict(x, prefix=""): def flatten_config_dict(dic, reorder=True):
"""Flattens config dict into single layer dict """
Flattens nested dict into single layer dict, for example:
Example:
flatten_config_dict({ flatten_config_dict({
MODEL: { "MODEL": {
FBNET_V2: { "FBNET_V2": {
ARCH_DEF: "val0" "ARCH_DEF": "val0",
} "ARCH": "val1",
},
} }
}) })
=> {"MODEL.FBNET_V2.ARCH_DEF": "val0"} => {"MODEL.FBNET_V2.ARCH_DEF": "val0", "MODEL.FBNET_V2.ARCH": "val1"}
"""
if not isinstance(x, dict):
return {prefix: x}
d = {} Args:
for k, v in x.items(): dic (dict or CfgNode): a nested dict whose keys are strings.
new_key = f"{prefix}.{k}" if prefix else k reorder (bool): if True, the returned dict will be sorted according to the keys;
d.update(flatten_config_dict(v, new_key)) otherwise original order will be preserved.
return d
Returns:
dic: a single-layer dict
"""
return _flatten_config_dict(dic, reorder=reorder, prefix="")
def config_dict_to_list_str(config_dict: Dict) -> List[str]: def config_dict_to_list_str(config_dict: Dict) -> List[str]:
...@@ -93,3 +52,49 @@ def config_dict_to_list_str(config_dict: Dict) -> List[str]: ...@@ -93,3 +52,49 @@ def config_dict_to_list_str(config_dict: Dict) -> List[str]:
str_list.append(k) str_list.append(k)
str_list.append(str(v)) str_list.append(str(v))
return str_list return str_list
def get_from_flattened_config_dict(dic, flattened_key, default=None):
    """
    Reads out a value from the nested config dict using flattened config key (i.e. all
    keys from each level put together with "." separator), the default value is returned
    if the flattened key doesn't exist.

    e.g. if the config dict is
        MODEL:
            TEST:
                SCORE_THRESHOLD: 0.7
    Then to access the value of SCORE_THRESHOLD, this API should be called

    >> score_threshold = get_from_flattened_config_dict(cfg, "MODEL.TEST.SCORE_THRESHOLD")

    Args:
        dic: the nested config dict to read from.
        flattened_key (str): "."-separated path of keys, one per nesting level.
        default: value returned when the path cannot be resolved, including when
            the path continues past a non-container leaf value.

    Returns:
        The value stored at ``flattened_key``, or ``default``.
    """
    for k in flattened_key.split("."):
        try:
            if k not in dic:
                return default
            dic = dic[k]
        except TypeError:
            # The path runs past a leaf (e.g. an int, None or str), which
            # cannot be searched or indexed by a string key: the key does not exist.
            return default
    return dic
def get_cfg_diff_table(cfg, original_cfg):
    """
    Build a table showing, side by side, the values that differ between two config dicts.

    Both configs are flattened with ``flatten_config_dict(..., reorder=True)`` and
    must produce exactly the same list of flattened keys.

    Args:
        cfg: the new config (shown in the "new value" column).
        original_cfg: the reference config (shown in the "old value" column).

    Returns:
        str: a pipe-format table (as produced by ``tabulate``) with one row per
            flattened key whose value differs; the table is returned, not printed.

    Raises:
        AssertionError: if the two configs do not have the same flattened keys.
    """
    old_flat = flatten_config_dict(original_cfg, reorder=True)
    new_flat = flatten_config_dict(cfg, reorder=True)

    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped when Python runs with -O.
    if list(old_flat.keys()) != list(new_flat.keys()):
        raise AssertionError(
            "configs must have identical flattened keys to be compared"
        )

    # Compare the already-flattened values directly instead of re-resolving each
    # flattened key by splitting on "."; re-resolving silently yields None on
    # both sides (hiding the difference) when a key cannot be walked that way.
    diff_table = [
        [full_key, old_flat[full_key], new_value]
        for full_key, new_value in new_flat.items()
        if old_flat[full_key] != new_value
    ]

    from tabulate import tabulate

    table = tabulate(
        diff_table,
        tablefmt="pipe",
        headers=["config key", "old value", "new value"],
    )
    return table
...@@ -18,8 +18,8 @@ from d2go.config import ( ...@@ -18,8 +18,8 @@ from d2go.config import (
CfgNode, CfgNode,
CONFIG_SCALING_METHOD_REGISTRY, CONFIG_SCALING_METHOD_REGISTRY,
temp_defrost, temp_defrost,
get_cfg_diff_table,
) )
from d2go.config.utils import get_cfg_diff_table
from d2go.data.build import build_d2go_train_loader from d2go.data.build import build_d2go_train_loader
from d2go.data.dataset_mappers import build_dataset_mapper from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
......
...@@ -8,6 +8,12 @@ import os ...@@ -8,6 +8,12 @@ import os
import unittest import unittest
from d2go.config import auto_scale_world_size, reroute_config_path from d2go.config import auto_scale_world_size, reroute_config_path
from d2go.config.utils import (
config_dict_to_list_str,
flatten_config_dict,
get_cfg_diff_table,
get_from_flattened_config_dict,
)
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.helper import get_resource_path from d2go.utils.testing.helper import get_resource_path
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
...@@ -16,9 +22,9 @@ from mobile_cv.common.misc.file_utils import make_temp_directory ...@@ -16,9 +22,9 @@ from mobile_cv.common.misc.file_utils import make_temp_directory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TestConfigs(unittest.TestCase): class TestConfig(unittest.TestCase):
def test_configs_load(self): def test_load_configs(self):
""" Make sure configs are loadable """ """Make sure configs are loadable"""
for location in ["detectron2", "detectron2go"]: for location in ["detectron2", "detectron2go"]:
root_dir = os.path.abspath(reroute_config_path(f"{location}://.")) root_dir = os.path.abspath(reroute_config_path(f"{location}://."))
...@@ -28,8 +34,8 @@ class TestConfigs(unittest.TestCase): ...@@ -28,8 +34,8 @@ class TestConfigs(unittest.TestCase):
logger.info("Loading {}...".format(fn)) logger.info("Loading {}...".format(fn))
GeneralizedRCNNRunner().get_default_cfg().merge_from_file(fn) GeneralizedRCNNRunner().get_default_cfg().merge_from_file(fn)
def test_arch_def_loads(self): def test_load_arch_defs(self):
""" Test arch def str-to-dict conversion compatible with merging """ """Test arch def str-to-dict conversion compatible with merging"""
default_cfg = GeneralizedRCNNRunner().get_default_cfg() default_cfg = GeneralizedRCNNRunner().get_default_cfg()
cfg = default_cfg.clone() cfg = default_cfg.clone()
cfg.merge_from_file(get_resource_path("arch_def_merging.yaml")) cfg.merge_from_file(get_resource_path("arch_def_merging.yaml"))
...@@ -70,6 +76,51 @@ class TestConfigs(unittest.TestCase): ...@@ -70,6 +76,51 @@ class TestConfigs(unittest.TestCase):
) )
class TestConfigUtils(unittest.TestCase):
    """Test util functions in config/utils.py"""

    def test_flatten_config_dict(self):
        """Check flatten config dict to single layer dict"""
        d = {"c0": {"c1": {"c2": 3}}, "b0": {"b1": "b2"}, "a0": "a1"}

        # reorder=True
        fdict = flatten_config_dict(d, reorder=True)
        gt = {"a0": "a1", "b0.b1": "b2", "c0.c1.c2": 3}
        self.assertEqual(fdict, gt)
        # dict equality ignores ordering, so compare the key order explicitly
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

        # reorder=False
        fdict = flatten_config_dict(d, reorder=False)
        gt = {"c0.c1.c2": 3, "b0.b1": "b2", "a0": "a1"}
        self.assertEqual(fdict, gt)
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

    def test_config_dict_to_list_str(self):
        """Check convert config dict to str list"""
        d = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        str_list = config_dict_to_list_str(d)
        gt = ["a0", "a1", "b0.b1", "b2", "c0.c1.c2", "3"]
        self.assertEqual(str_list, gt)

    def test_get_from_flattened_config_dict(self):
        """Check reading values through a flattened key"""
        d = {"MODEL": {"MIN_DIM_SIZE": 360}}
        self.assertEqual(
            get_from_flattened_config_dict(d, "MODEL.MIN_DIM_SIZE"), 360
        )  # exist
        self.assertIsNone(
            get_from_flattened_config_dict(d, "MODEL.MODEL.INPUT_SIZE")
        )  # non-exist
        self.assertEqual(
            get_from_flattened_config_dict(d, "MODEL.MODEL.INPUT_SIZE", default=-1),
            -1,
        )  # non-exist, explicit default

    def test_get_cfg_diff_table(self):
        """Check compare two dicts"""
        d1 = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        d2 = {"a0": "a1", "b0": {"b1": "b3"}, "c0": {"c1": {"c2": 4}}}
        table = get_cfg_diff_table(d1, d2)
        self.assertNotIn("a0", table)  # a0 are the same
        self.assertIn("b0.b1", table)  # b0.b1 are different
        self.assertIn("c0.c1.c2", table)  # c0.c1.c2 are different
class TestAutoScaleWorldSize(unittest.TestCase): class TestAutoScaleWorldSize(unittest.TestCase):
def test_8gpu_to_1gpu(self): def test_8gpu_to_1gpu(self):
""" """
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import unittest
from d2go.config.utils import (
config_dict_to_list_str,
flatten_config_dict,
str_wrap_fbnet_arch_def,
)
class TestConfigUtils(unittest.TestCase):
    """Tests for the helpers imported from d2go.config.utils"""

    def test_str_wrap_fbnet_arch_def(self):
        """Check that fbnet modeldef converted to str"""
        original = {"MODEL": {"FBNET_V2": {"ARCH_DEF": {"key0": "val0"}}}}
        wrapped = str_wrap_fbnet_arch_def(original)
        expected = {"MODEL": {"FBNET_V2": {"ARCH_DEF": """'{"key0": "val0"}'"""}}}
        self.assertEqual(wrapped, expected)
        # the input must not have been modified
        self.assertNotEqual(original, wrapped)

        # check only fbnet arch is changed
        original = {"a0": "a1", "b0": {"b1": "b2"}}
        expected = copy.deepcopy(original)
        wrapped = str_wrap_fbnet_arch_def(original)
        self.assertEqual(wrapped, expected)

    def test_flatten_config_dict(self):
        """Check flatten config dict to single layer dict"""
        nested = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        expected = {"a0": "a1", "b0.b1": "b2", "c0.c1.c2": 3}
        self.assertEqual(flatten_config_dict(nested), expected)

    def test_config_dict_to_list_str(self):
        """Check convert config dict to str list"""
        nested = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        expected = ["a0", "a1", "b0.b1", "b2", "c0.c1.c2", "3"]
        self.assertEqual(config_dict_to_list_str(nested), expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment