Commit 313eb946 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Support cfg_diff for mismatched keys.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/277

Support cfg_diff for mismatched keys.

Reviewed By: wat3rBro

Differential Revision: D36737254

fbshipit-source-id: a1c189c92a24f3c109d9a427f135e53876b91624
parent 6f036248
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os import os
from enum import Enum from enum import Enum
from typing import Any, Dict, List from typing import Any, Dict, List
...@@ -121,9 +122,29 @@ def get_cfg_diff_table(cfg, original_cfg): ...@@ -121,9 +122,29 @@ def get_cfg_diff_table(cfg, original_cfg):
all_old_keys = list(flatten_config_dict(original_cfg, reorder=True).keys()) all_old_keys = list(flatten_config_dict(original_cfg, reorder=True).keys())
all_new_keys = list(flatten_config_dict(cfg, reorder=True).keys()) all_new_keys = list(flatten_config_dict(cfg, reorder=True).keys())
assert all_old_keys == all_new_keys
diff_table = [] diff_table = []
if all_old_keys != all_new_keys:
logger = logging.getLogger(__name__)
mismatched_old_keys = set(all_old_keys) - set(all_new_keys)
mismatched_new_keys = set(all_new_keys) - set(all_old_keys)
logger.warning(
"Config key mismatched.\n"
f"Mismatched old keys: {mismatched_old_keys}\n"
f"Mismatched new keys: {mismatched_new_keys}"
)
for old_key in mismatched_old_keys:
old_value = get_from_flattened_config_dict(original_cfg, old_key)
diff_table.append([old_key, old_value, "Key not exists"])
for new_key in mismatched_new_keys:
new_value = get_from_flattened_config_dict(cfg, new_key)
diff_table.append([new_key, "Key not exists", new_value])
# filter out mis-matched keys
all_old_keys = [x for x in all_old_keys if x not in mismatched_old_keys]
all_new_keys = [x for x in all_new_keys if x not in mismatched_new_keys]
for full_key in all_new_keys: for full_key in all_new_keys:
old_value = get_from_flattened_config_dict(original_cfg, full_key) old_value = get_from_flattened_config_dict(original_cfg, full_key)
new_value = get_from_flattened_config_dict(cfg, full_key) new_value = get_from_flattened_config_dict(cfg, full_key)
......
...@@ -205,6 +205,17 @@ class TestConfigUtils(unittest.TestCase): ...@@ -205,6 +205,17 @@ class TestConfigUtils(unittest.TestCase):
self.assertTrue("b0.b1" in table) # b0.b1 are different self.assertTrue("b0.b1" in table) # b0.b1 are different
self.assertTrue("c0.c1.c2" in table) # c0.c1.c2 are different self.assertTrue("c0.c1.c2" in table) # c0.c1.c2 are different
def test_get_cfg_diff_table_mismatched_keys(self):
"""Check compare two dicts, the keys are mismatched"""
d_orig = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
d_new = {"a0": "a1", "b0": {"b1": "b3"}, "c0": {"c4": {"c2": 4}}}
table = get_cfg_diff_table(d_new, d_orig)
self.assertTrue("a0" not in table) # a0 are the same
self.assertTrue("b0.b1" in table) # b0.b1 are different
self.assertTrue("c0.c1.c2" in table) # c0.c1.c2 key mismatched
self.assertTrue("c0.c4.c2" in table) # c0.c4.c2 key mismatched
self.assertTrue("Key not exists" in table) # has mismatched key
class TestAutoScaleWorldSize(unittest.TestCase): class TestAutoScaleWorldSize(unittest.TestCase):
def test_8gpu_to_1gpu(self): def test_8gpu_to_1gpu(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment