Commit 251eaed2 authored by Jonathan Zeltser's avatar Jonathan Zeltser Committed by Facebook GitHub Bot
Browse files

save and print diff config

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/202

This diff print the diff between the default config and the full config at the start of the run

Reviewed By: wat3rBro

Differential Revision: D35346096

fbshipit-source-id: 1ce9b58a8d613d1dd572358ce1e51462c90cb337
parent 848f9944
......@@ -135,3 +135,55 @@ def get_cfg_diff_table(cfg, original_cfg):
headers=["config key", "old value", "new value"],
)
return table
def get_diff_cfg(old_cfg, new_cfg):
"""
outputs a CfgNode containing keys, values appearing in new_cfg and not in old_cfg.
If `new_allowed` is not set, then new keys will throw a KeyError
old_cfg: CfgNode, the original config, usually the dafulat
new_cfg: CfgNode, the full config being passed by the user
if new allowed is not set on new_cfg, key error is raised
returns: CfgNode, a config containing only key, value changes between old_cfg and new_cfg
example:
Cfg1:
SYSTEM:
NUM_GPUS: 2
TRAIN:
SCALES: (1, 2)
DATASETS:
train_2017:
17: 1
18: 1
Cfg2:
SYSTEM:
NUM_GPUS: 2
TRAIN:
SCALES: (4, 5, 8)
DATASETS:
train_2017:
17: 1
18: 1
get_diff_cfg(Cfg1, Cfg2) gives:
TRAIN:
SCALES: (8, 16, 32)
"""
def get_diff_cfg_rec(old_cfg, new_cfg, out):
for key in new_cfg.keys():
if key not in old_cfg.keys() and old_cfg.is_new_allowed():
out[key] = new_cfg[key]
elif old_cfg[key] != new_cfg[key]:
if type(new_cfg[key]) is type(out):
out[key] = out.__class__()
out[key] = get_diff_cfg_rec(old_cfg[key], new_cfg[key], out[key])
else:
out[key] = new_cfg[key]
return out
out = new_cfg.__class__()
return get_diff_cfg_rec(old_cfg, new_cfg, out)
......@@ -15,6 +15,7 @@ from d2go.config import (
reroute_config_path,
temp_defrost,
)
from d2go.config.utils import get_diff_cfg
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import create_runner, GeneralizedRCNNRunner
from d2go.utils.helper import run_once
......@@ -25,7 +26,6 @@ from detectron2.utils.logger import setup_logger
from detectron2.utils.serialize import PicklableWrapper
from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem_if_fail
logger = logging.getLogger(__name__)
......@@ -209,7 +209,10 @@ def setup_after_launch(cfg: CN, output_dir: str, runner):
runner = initialize_runner(runner, cfg)
log_info(cfg, runner)
dump_cfg(
get_diff_cfg(runner.get_default_cfg(), cfg),
os.path.join(output_dir, "diff_config.yaml"),
)
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
......
......@@ -10,6 +10,7 @@ import unittest
from d2go.config import CfgNode
from d2go.config import auto_scale_world_size, reroute_config_path
from d2go.config.utils import (
get_diff_cfg,
config_dict_to_list_str,
flatten_config_dict,
get_cfg_diff_table,
......@@ -148,6 +149,48 @@ class TestConfigUtils(unittest.TestCase):
get_from_flattened_config_dict(d, "MODEL.MODEL.INPUT_SIZE"), None
) # non-exist
def test_get_diff_cfg(self):
"""check config that is diff from default config, no new keys"""
# create base config
cfg1 = CfgNode()
cfg1.A = CfgNode()
cfg1.A.Y = 2
# case 1: new allowed not set, new config has only old keys
cfg2 = cfg1.clone()
cfg2.set_new_allowed(False)
cfg2.A.Y = 3
gt = CfgNode()
gt.A = CfgNode()
gt.A.Y = 3
self.assertEqual(gt, get_diff_cfg(cfg1, cfg2))
def test_diff_cfg_no_new_allowed(self):
"""check that if new_allowed is False, new keys cause key error"""
# create base config
cfg1 = CfgNode()
cfg1.A = CfgNode()
cfg1.A.set_new_allowed(False)
cfg1.A.Y = 2
# case 2: new allowed not set, new config has new keys
cfg2 = cfg1.clone()
cfg2.A.X = 2
self.assertRaises(KeyError, get_diff_cfg, cfg1, cfg2)
def test_diff_cfg_with_new_allowed(self):
"""diff config with new keys and new_allowed set to True"""
# create base config
cfg1 = CfgNode()
cfg1.A = CfgNode()
cfg1.A.set_new_allowed(True)
cfg1.A.Y = 2
# case 3: new allowed set, new config has new keys
cfg2 = cfg1.clone()
cfg2.A.X = 2
gt = CfgNode()
gt.A = CfgNode()
gt.A.X = 2
self.assertEqual(gt, get_diff_cfg(cfg1, cfg2))
def test_get_cfg_diff_table(self):
"""Check compare two dicts"""
d1 = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment