test_config.py 7.3 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import glob
import logging
import os
import unittest

10
from d2go.config import CfgNode
facebook-github-bot's avatar
facebook-github-bot committed
11
from d2go.config import auto_scale_world_size, reroute_config_path
Yanghan Wang's avatar
Yanghan Wang committed
12
13
14
15
16
17
from d2go.config.utils import (
    config_dict_to_list_str,
    flatten_config_dict,
    get_cfg_diff_table,
    get_from_flattened_config_dict,
)
facebook-github-bot's avatar
facebook-github-bot committed
18
from d2go.runner import GeneralizedRCNNRunner
Sam Tsai's avatar
Sam Tsai committed
19
from d2go.utils.testing.helper import get_resource_path
facebook-github-bot's avatar
facebook-github-bot committed
20
21
22
23
24
25
from mobile_cv.common.misc.file_utils import make_temp_directory


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


Yanghan Wang's avatar
Yanghan Wang committed
26
27
28
class TestConfig(unittest.TestCase):
    """Loading, dumping and merging behavior of D2Go configs."""

    def test_load_configs(self):
        """Make sure the yaml configs under detectron2:// and detectron2go:// are loadable.

        Every ``*.yaml`` file found recursively under each rerouted root (except
        paths containing "fbnas") is merged into a fresh clone of the
        GeneralizedRCNNRunner default config.
        """
        # Build the default config once and clone it per file, rather than
        # constructing a new runner for every yaml file.
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()
        for location in ["detectron2", "detectron2go"]:
            root_dir = os.path.abspath(reroute_config_path(f"{location}://."))
            files = glob.glob(os.path.join(root_dir, "**/*.yaml"), recursive=True)
            files = [f for f in files if "fbnas" not in f]
            self.assertGreater(len(files), 0, f"no yaml config found under {root_dir}")
            for fn in sorted(files):
                logger.info("Loading %s...", fn)
                # subTest names the failing file and keeps checking the others.
                with self.subTest(config_file=fn):
                    default_cfg.clone().merge_from_file(fn)

    def test_load_arch_defs(self):
        """Test arch def str-to-dict conversion compatible with merging.

        A config containing an arch def is dumped to yaml and then merged back
        into a fresh default config; the test passes if merging does not raise.
        """
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()
        cfg = default_cfg.clone()
        cfg.merge_from_file(get_resource_path("arch_def_merging.yaml"))

        with make_temp_directory("detectron2go_tmp") as tmp_dir:
            # Dump out config with arch def
            file_name = os.path.join(tmp_dir, "test_archdef_config.yaml")
            with open(file_name, "w") as f:
                f.write(cfg.dump())

            # Attempt to reload the config
            another_cfg = default_cfg.clone()
            another_cfg.merge_from_file(file_name)

    def test_base_reroute(self):
        """Check that configs using rerouted file(s) as base are merged correctly."""
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()

        # use rerouted file as base
        cfg = default_cfg.clone()
        cfg.merge_from_file(get_resource_path("rerouted_base.yaml"))
        self.assertEqual(cfg.MODEL.MASK_ON, True)  # base is loaded
        self.assertEqual(cfg.MODEL.FBNET_V2.ARCH, "test")  # non-base is loaded

        # use multiple files as base
        cfg = default_cfg.clone()
        cfg.merge_from_file(get_resource_path("rerouted_multi_base.yaml"))
        self.assertEqual(cfg.MODEL.MASK_ON, True)  # base is loaded
        self.assertEqual(cfg.MODEL.FBNET_V2.ARCH, "FBNetV3_A")  # second base is loaded
        self.assertEqual(cfg.OUTPUT_DIR, "test")  # non-base is loaded

    def test_default_cfg_dump_and_load(self):
        """Check that the default config, dumped to yaml, can be merged back."""
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()

        cfg = default_cfg.clone()
        with make_temp_directory("detectron2go_tmp") as tmp_dir:
            file_name = os.path.join(tmp_dir, "config.yaml")
            # this is same as the one in fblearner_launch_utils_detectron2go.py
            with open(file_name, "w") as f:
                f.write(cfg.dump(default_flow_style=False))

            # check if the dumped config file can be merged
            cfg.merge_from_file(file_name)

    def test_default_cfg_deprecated_keys(self):
        """Deprecated keys are accepted by merge_from_list; renamed keys raise KeyError."""
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()

        # a warning will be printed for deprecated keys
        default_cfg.merge_from_list(["QUANTIZATION.QAT.LOAD_PRETRAINED", True])
        # exception will raise for renamed keys
        with self.assertRaises(KeyError):
            default_cfg.merge_from_list(["QUANTIZATION.QAT.BACKEND", "fbgemm"])

    def test_merge_from_list_with_new_allowed(self):
        """
        YACS's merge_from_list doesn't take new_allowed into account; D2Go
        overrides that behavior, and this test covers it.
        """
        # new_allowed is not set
        cfg = CfgNode()
        cfg.A = CfgNode()
        cfg.A.X = 1
        with self.assertRaises(Exception):
            cfg.merge_from_list(["A.Y", "2"])

        # new_allowed is set for sub key
        cfg = CfgNode()
        cfg.A = CfgNode(new_allowed=True)
        cfg.A.X = 1
        cfg.merge_from_list(["A.Y", "2"])
        self.assertEqual(cfg.A.Y, 2)  # note that the string will be converted to number
        # however new_allowed is not set for root key
        with self.assertRaises(Exception):
            cfg.merge_from_list(["B", "3"])
facebook-github-bot's avatar
facebook-github-bot committed
115

Yanghan Wang's avatar
Yanghan Wang committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class TestConfigUtils(unittest.TestCase):
    """Test util functions in config/utils.py"""

    def test_flatten_config_dict(self):
        """Check flatten config dict to single layer dict"""
        d = {"c0": {"c1": {"c2": 3}}, "b0": {"b1": "b2"}, "a0": "a1"}

        # reorder=True: keys come back sorted
        fdict = flatten_config_dict(d, reorder=True)
        gt = {"a0": "a1", "b0.b1": "b2", "c0.c1.c2": 3}
        self.assertEqual(fdict, gt)
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

        # reorder=False: keys keep the input order
        fdict = flatten_config_dict(d, reorder=False)
        gt = {"c0.c1.c2": 3, "b0.b1": "b2", "a0": "a1"}
        self.assertEqual(fdict, gt)
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

    def test_config_dict_to_list_str(self):
        """Check convert config dict to str list"""
        d = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        str_list = config_dict_to_list_str(d)
        gt = ["a0", "a1", "b0.b1", "b2", "c0.c1.c2", "3"]
        self.assertEqual(str_list, gt)

    def test_get_from_flattened_config_dict(self):
        """Check lookup of a dotted key in a nested config dict"""
        d = {"MODEL": {"MIN_DIM_SIZE": 360}}
        # existing key
        self.assertEqual(get_from_flattened_config_dict(d, "MODEL.MIN_DIM_SIZE"), 360)
        # non-existing key
        self.assertIsNone(get_from_flattened_config_dict(d, "MODEL.MODEL.INPUT_SIZE"))

    def test_get_cfg_diff_table(self):
        """Check compare two dicts"""
        d1 = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        d2 = {"a0": "a1", "b0": {"b1": "b3"}, "c0": {"c1": {"c2": 4}}}
        table = get_cfg_diff_table(d1, d2)
        # assertIn/assertNotIn report the table contents on failure
        self.assertNotIn("a0", table)  # a0 are the same
        self.assertIn("b0.b1", table)  # b0.b1 are different
        self.assertIn("c0.c1.c2", table)  # c0.c1.c2 are different


facebook-github-bot's avatar
facebook-github-bot committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
class TestAutoScaleWorldSize(unittest.TestCase):
    """Tests for auto_scale_world_size on the GeneralizedRCNNRunner default config."""

    def test_8gpu_to_1gpu(self):
        """
        when scaling a 8-gpu config to 1-gpu one, the batch size will be reduced by 8x
        """
        cfg = GeneralizedRCNNRunner().get_default_cfg()
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 8)
        batch_size_x8 = cfg.SOLVER.IMS_PER_BATCH
        # Use a TestCase assertion rather than a bare `assert`, which is
        # stripped when Python runs with -O.
        self.assertEqual(
            batch_size_x8 % 8, 0, "default batch size is not multiple of 8"
        )
        auto_scale_world_size(cfg, new_world_size=1)
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 1)
        self.assertEqual(cfg.SOLVER.IMS_PER_BATCH * 8, batch_size_x8)

    def test_not_scale_for_zero_world_size(self):
        """
        when reference world size is 0, no scaling should happen
        """
        cfg = GeneralizedRCNNRunner().get_default_cfg()
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 8)
        cfg.SOLVER.REFERENCE_WORLD_SIZE = 0
        original_batch_size = cfg.SOLVER.IMS_PER_BATCH
        auto_scale_world_size(cfg, new_world_size=1)
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 0)
        self.assertEqual(cfg.SOLVER.IMS_PER_BATCH, original_batch_size)


if __name__ == "__main__":
    unittest.main()