"vscode:/vscode.git/clone" did not exist on "fa7ccb3316dccdf0326913222c337da20b436251"
test_config.py 10.2 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import glob
import logging
import os
import unittest

10
11
12
13
14
15
from d2go.config import (
    auto_scale_world_size,
    CfgNode,
    load_full_config_from_file,
    reroute_config_path,
)
Yanghan Wang's avatar
Yanghan Wang committed
16
17
18
19
from d2go.config.utils import (
    config_dict_to_list_str,
    flatten_config_dict,
    get_cfg_diff_table,
20
    get_diff_cfg,
Yanghan Wang's avatar
Yanghan Wang committed
21
22
    get_from_flattened_config_dict,
)
23
from d2go.registry.builtin import CONFIG_UPDATER_REGISTRY
facebook-github-bot's avatar
facebook-github-bot committed
24
from d2go.runner import GeneralizedRCNNRunner
Sam Tsai's avatar
Sam Tsai committed
25
from d2go.utils.testing.helper import get_resource_path
facebook-github-bot's avatar
facebook-github-bot committed
26
27
28
29
30
31
from mobile_cv.common.misc.file_utils import make_temp_directory


logger = logging.getLogger(__name__)


Yanghan Wang's avatar
Yanghan Wang committed
32
33
34
class TestConfig(unittest.TestCase):
    """Loading, dumping and merging checks for the GeneralizedRCNNRunner config."""

    def test_load_configs(self):
        """Every shipped YAML config (except fbnas ones) merges into a fresh default config."""
        for prefix in ("detectron2", "detectron2go"):
            config_root = os.path.abspath(reroute_config_path(f"{prefix}://."))
            pattern = os.path.join(config_root, "**/*.yaml")
            # fbnas configs are deliberately excluded from this sweep
            yaml_paths = sorted(
                path
                for path in glob.glob(pattern, recursive=True)
                if "fbnas" not in path
            )
            # guard against a silently empty search (e.g. a broken reroute)
            self.assertGreater(len(yaml_paths), 0)
            for path in yaml_paths:
                logger.info("Loading {}...".format(path))
                GeneralizedRCNNRunner.get_default_cfg().merge_from_file(path)

    def test_load_arch_defs(self):
        """A config containing arch defs survives a dump followed by a re-merge."""
        base_cfg = GeneralizedRCNNRunner.get_default_cfg()
        archdef_cfg = base_cfg.clone()
        archdef_cfg.merge_from_file(get_resource_path("arch_def_merging.yaml"))

        with make_temp_directory("detectron2go_tmp") as tmp_dir:
            dump_path = os.path.join(tmp_dir, "test_archdef_config.yaml")
            with open(dump_path, "w") as fh:
                fh.write(archdef_cfg.dump())

            # merging the dumped YAML into a fresh default must not raise
            reloaded_cfg = base_cfg.clone()
            reloaded_cfg.merge_from_file(dump_path)

    def test_base_reroute(self):
        """Rerouted _BASE_ paths (single and multiple) are resolved when merging."""
        base_cfg = GeneralizedRCNNRunner.get_default_cfg()

        def merged_with(resource_name):
            # merge one test resource into a private copy of the defaults
            cfg = base_cfg.clone()
            cfg.merge_from_file(get_resource_path(resource_name))
            return cfg

        # a single rerouted file used as base
        single = merged_with("rerouted_base.yaml")
        self.assertEqual(single.MODEL.MASK_ON, True)  # value from the base
        self.assertEqual(single.MODEL.FBNET_V2.ARCH, "test")  # value from the file itself

        # several files used as bases
        multi = merged_with("rerouted_multi_base.yaml")
        self.assertEqual(multi.MODEL.MASK_ON, True)  # value from the base
        self.assertEqual(multi.MODEL.FBNET_V2.ARCH, "FBNetV3_A")  # value from the second base
        self.assertEqual(multi.OUTPUT_DIR, "test")  # value from the file itself

    def test_default_cfg_dump_and_load(self):
        """The default config, dumped in block style, can be merged back into itself."""
        cfg = GeneralizedRCNNRunner.get_default_cfg().clone()
        with make_temp_directory("detectron2go_tmp") as tmp_dir:
            dump_path = os.path.join(tmp_dir, "config.yaml")
            # same dump call as the one in fblearner_launch_utils_detectron2go.py
            with open(dump_path, "w") as fh:
                fh.write(cfg.dump(default_flow_style=False))

            # the dumped file must be mergeable
            cfg.merge_from_file(dump_path)

    def test_default_cfg_deprecated_keys(self):
        """Deprecated keys are tolerated by merge_from_list; renamed keys raise KeyError."""
        cfg = GeneralizedRCNNRunner.get_default_cfg()

        # deprecated key: merging must not raise
        cfg.merge_from_list(["QUANTIZATION.QAT.LOAD_PRETRAINED", True])

        # renamed key: merging is rejected
        with self.assertRaises(KeyError):
            cfg.merge_from_list(["QUANTIZATION.QAT.BACKEND", "fbgemm"])

    def test_merge_from_list_with_new_allowed(self):
        """
        D2Go overrides YACS's merge_from_list so that it honours new_allowed
        (plain YACS ignores it); this test covers that override.
        """
        # no node allows new keys: adding A.Y fails
        strict_cfg = CfgNode()
        strict_cfg.A = CfgNode()
        strict_cfg.A.X = 1
        with self.assertRaises(Exception):
            strict_cfg.merge_from_list(["A.Y", "2"])

        # only the sub node A allows new keys
        relaxed_cfg = CfgNode()
        relaxed_cfg.A = CfgNode(new_allowed=True)
        relaxed_cfg.A.X = 1
        relaxed_cfg.merge_from_list(["A.Y", "2"])
        # the string value "2" ends up as the number 2
        self.assertEqual(relaxed_cfg.A.Y, 2)
        # the root still forbids new keys
        with self.assertRaises(Exception):
            relaxed_cfg.merge_from_list(["B", "3"])

facebook-github-bot's avatar
facebook-github-bot committed
121

Yanghan Wang's avatar
Yanghan Wang committed
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class TestConfigUtils(unittest.TestCase):
    """Test util functions in config/utils.py"""

    def test_flatten_config_dict(self):
        """flatten_config_dict turns a nested dict into a single-level dict with dotted keys."""
        d = {"c0": {"c1": {"c2": 3}}, "b0": {"b1": "b2"}, "a0": "a1"}

        # reorder=True: for this input the keys are expected in alphabetical order
        fdict = flatten_config_dict(d, reorder=True)
        gt = {"a0": "a1", "b0.b1": "b2", "c0.c1.c2": 3}
        self.assertEqual(fdict, gt)
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

        # reorder=False: the insertion order of the input is preserved
        fdict = flatten_config_dict(d, reorder=False)
        gt = {"c0.c1.c2": 3, "b0.b1": "b2", "a0": "a1"}
        self.assertEqual(fdict, gt)
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

    def test_config_dict_to_list_str(self):
        """config_dict_to_list_str yields alternating dotted keys and stringified values."""
        d = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        str_list = config_dict_to_list_str(d)
        # note the integer 3 is emitted as the string "3"
        gt = ["a0", "a1", "b0.b1", "b2", "c0.c1.c2", "3"]
        self.assertEqual(str_list, gt)

    def test_get_from_flattened_config_dict(self):
        """Dotted-key lookup returns the value if present, otherwise None."""
        d = {"MODEL": {"MIN_DIM_SIZE": 360}}
        # existing key
        self.assertEqual(get_from_flattened_config_dict(d, "MODEL.MIN_DIM_SIZE"), 360)
        # non-existing key
        self.assertIsNone(get_from_flattened_config_dict(d, "MODEL.MODEL.INPUT_SIZE"))

    def test_get_diff_cfg(self):
        """get_diff_cfg returns only the changed values when no keys are added."""
        # create base config
        cfg1 = CfgNode()
        cfg1.A = CfgNode()
        cfg1.A.Y = 2
        # case 1: new_allowed not set, new config has only old keys
        cfg2 = cfg1.clone()
        cfg2.set_new_allowed(False)
        cfg2.A.Y = 3
        gt = CfgNode()
        gt.A = CfgNode()
        gt.A.Y = 3
        self.assertEqual(gt, get_diff_cfg(cfg1, cfg2))

    def test_diff_cfg_no_new_allowed(self):
        """If new_allowed is False, keys added in the new config cause a KeyError."""
        # create base config
        cfg1 = CfgNode()
        cfg1.A = CfgNode()
        cfg1.A.set_new_allowed(False)
        cfg1.A.Y = 2
        # case 2: new_allowed not set, new config has new keys
        cfg2 = cfg1.clone()
        cfg2.A.X = 2
        with self.assertRaises(KeyError):
            get_diff_cfg(cfg1, cfg2)

    def test_diff_cfg_with_new_allowed(self):
        """If new_allowed is True, keys added in the new config appear in the diff."""
        # create base config
        cfg1 = CfgNode()
        cfg1.A = CfgNode()
        cfg1.A.set_new_allowed(True)
        cfg1.A.Y = 2
        # case 3: new_allowed set, new config has new keys
        cfg2 = cfg1.clone()
        cfg2.A.X = 2
        gt = CfgNode()
        gt.A = CfgNode()
        gt.A.X = 2
        self.assertEqual(gt, get_diff_cfg(cfg1, cfg2))

    def test_get_cfg_diff_table(self):
        """The diff table lists differing dotted keys and omits identical ones."""
        d1 = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        d2 = {"a0": "a1", "b0": {"b1": "b3"}, "c0": {"c1": {"c2": 4}}}
        table = get_cfg_diff_table(d1, d2)
        # assertIn/assertNotIn report the table contents on failure,
        # unlike assertTrue(x in table)
        self.assertNotIn("a0", table)  # a0 is the same
        self.assertIn("b0.b1", table)  # b0.b1 differs
        self.assertIn("c0.c1.c2", table)  # c0.c1.c2 differs

    def test_get_cfg_diff_table_mismatched_keys(self):
        """The diff table marks keys that exist in only one of the two dicts."""
        d_orig = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        d_new = {"a0": "a1", "b0": {"b1": "b3"}, "c0": {"c4": {"c2": 4}}}
        table = get_cfg_diff_table(d_new, d_orig)
        self.assertNotIn("a0", table)  # a0 is the same
        self.assertIn("b0.b1", table)  # b0.b1 differs
        self.assertIn("c0.c1.c2", table)  # c0.c1.c2 only in d_orig
        self.assertIn("c0.c4.c2", table)  # c0.c4.c2 only in d_new
        self.assertIn("Key not exists", table)  # mismatched keys are labelled

Yanghan Wang's avatar
Yanghan Wang committed
219

facebook-github-bot's avatar
facebook-github-bot committed
220
221
222
223
224
class TestAutoScaleWorldSize(unittest.TestCase):
    """Tests for auto_scale_world_size on the default GeneralizedRCNNRunner config."""

    def test_8gpu_to_1gpu(self):
        """
        When scaling an 8-GPU config to a 1-GPU one, the batch size is reduced by 8x.
        """
        cfg = GeneralizedRCNNRunner.get_default_cfg()
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 8)
        batch_size_x8 = cfg.SOLVER.IMS_PER_BATCH
        # Use a TestCase assertion instead of a bare `assert`: bare asserts are
        # stripped under `python -O`, which would silently skip this precondition.
        self.assertEqual(
            batch_size_x8 % 8, 0, "default batch size is not multiple of 8"
        )
        auto_scale_world_size(cfg, new_world_size=1)
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 1)
        self.assertEqual(cfg.SOLVER.IMS_PER_BATCH * 8, batch_size_x8)

    def test_not_scale_for_zero_world_size(self):
        """
        When the reference world size is 0, no scaling should happen.
        """
        cfg = GeneralizedRCNNRunner.get_default_cfg()
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 8)
        cfg.SOLVER.REFERENCE_WORLD_SIZE = 0
        original_batch_size = cfg.SOLVER.IMS_PER_BATCH
        auto_scale_world_size(cfg, new_world_size=1)
        # both the reference world size and the batch size are left untouched
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 0)
        self.assertEqual(cfg.SOLVER.IMS_PER_BATCH, original_batch_size)


246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
class TestConfigDefaultsGen(unittest.TestCase):
    """Config updaters registered in CONFIG_UPDATER_REGISTRY shape the default config."""

    def test_case1(self):
        # The updater functions are defined locally, but the registration goes
        # into the shared CONFIG_UPDATER_REGISTRY and is not undone afterwards.
        # NOTE(review): presumably defaults_gen_case1.yaml refers to these
        # updaters by function name, so `_test1` / `_test2` must keep their names.
        @CONFIG_UPDATER_REGISTRY.register()
        def _test1(cfg):
            cfg.TEST1 = CfgNode()
            cfg.TEST1.X = 1
            return cfg

        @CONFIG_UPDATER_REGISTRY.register()
        def _test2(cfg):
            cfg.TEST2 = CfgNode()
            cfg.TEST2.Y = 2
            return cfg

        loaded_cfg = load_full_config_from_file(
            get_resource_path("defaults_gen_case1.yaml")
        )
        defaults = loaded_cfg.get_default_cfg()

        # defaults carry the values set by the updaters
        self.assertEqual(defaults.TEST1.X, 1)
        self.assertEqual(defaults.TEST2.Y, 2)
        # the loaded config reflects the YAML override of TEST1.X
        self.assertEqual(loaded_cfg.TEST1.X, 3)


facebook-github-bot's avatar
facebook-github-bot committed
272
273
if __name__ == "__main__":
    # Run the tests in this module when it is executed as a script.
    unittest.main(module="__main__")