test_config.py 5.79 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import glob
import logging
import os
import unittest

from d2go.config import auto_scale_world_size, reroute_config_path
Yanghan Wang's avatar
Yanghan Wang committed
11
12
13
14
15
16
from d2go.config.utils import (
    config_dict_to_list_str,
    flatten_config_dict,
    get_cfg_diff_table,
    get_from_flattened_config_dict,
)
facebook-github-bot's avatar
facebook-github-bot committed
17
from d2go.runner import GeneralizedRCNNRunner
Sam Tsai's avatar
Sam Tsai committed
18
from d2go.utils.testing.helper import get_resource_path
facebook-github-bot's avatar
facebook-github-bot committed
19
20
21
22
23
24
from mobile_cv.common.misc.file_utils import make_temp_directory


logger = logging.getLogger(__name__)


Yanghan Wang's avatar
Yanghan Wang committed
25
26
27
class TestConfig(unittest.TestCase):
    """Loading, dumping and re-merging of the default GeneralizedRCNNRunner config."""

    def test_load_configs(self):
        """Make sure configs are loadable.

        Walks the YAML files under the ``detectron2://`` and ``detectron2go://``
        config roots (resolved via ``reroute_config_path``), skips any path
        containing "fbnas", and merges each remaining file into a fresh default
        config. Each location and each file runs in its own ``subTest`` so one
        broken config does not hide failures in the others.
        """
        for location in ["detectron2", "detectron2go"]:
            with self.subTest(location=location):
                root_dir = os.path.abspath(reroute_config_path(f"{location}://."))
                files = glob.glob(os.path.join(root_dir, "**/*.yaml"), recursive=True)
                files = [f for f in files if "fbnas" not in f]
                # Guard against a mis-resolved root: an empty list would make
                # the loop below pass vacuously.
                self.assertGreater(len(files), 0)
                for fn in sorted(files):
                    with self.subTest(config_file=fn):
                        logger.info("Loading %s...", fn)
                        GeneralizedRCNNRunner().get_default_cfg().merge_from_file(fn)

    def test_load_arch_defs(self):
        """Test arch def str-to-dict conversion compatible with merging.

        Merges the ``arch_def_merging.yaml`` test resource into a default config,
        dumps the result to YAML, and checks that the dump can be merged back
        into a fresh copy of the default config without raising.
        """
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()
        cfg = default_cfg.clone()
        cfg.merge_from_file(get_resource_path("arch_def_merging.yaml"))

        with make_temp_directory("detectron2go_tmp") as tmp_dir:
            # Dump out config with arch def
            file_name = os.path.join(tmp_dir, "test_archdef_config.yaml")
            with open(file_name, "w", encoding="utf-8") as f:
                f.write(cfg.dump())

            # Attempt to reload the config; the test passes if this does not raise.
            another_cfg = default_cfg.clone()
            another_cfg.merge_from_file(file_name)

    def test_default_cfg_dump_and_load(self):
        """Check that a default config dumped with ``default_flow_style=False``
        can be merged back into a config without raising."""
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()

        cfg = default_cfg.clone()
        with make_temp_directory("detectron2go_tmp") as tmp_dir:
            file_name = os.path.join(tmp_dir, "config.yaml")
            # this is same as the one in fblearner_launch_utils_detectron2go.py
            with open(file_name, "w", encoding="utf-8") as f:
                f.write(cfg.dump(default_flow_style=False))

            # check if the dumped config file can be merged
            cfg.merge_from_file(file_name)

    def test_default_cfg_deprecated_keys(self):
        """Deprecated keys merge without error, while renamed keys raise KeyError."""
        default_cfg = GeneralizedRCNNRunner().get_default_cfg()

        # a warning will be printed for deprecated keys
        default_cfg.merge_from_list(["QUANTIZATION.QAT.LOAD_PRETRAINED", True])
        # exception will raise for renamed keys
        with self.assertRaises(KeyError):
            default_cfg.merge_from_list(["QUANTIZATION.QAT.BACKEND", "fbgemm"])


Yanghan Wang's avatar
Yanghan Wang committed
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class TestConfigUtils(unittest.TestCase):
    """Test util functions in config/utils.py"""

    def test_flatten_config_dict(self):
        """Check flatten config dict to single layer dict.

        Nested keys are joined with "." and both the contents and the key
        order of the result are asserted for each ``reorder`` setting.
        """
        d = {"c0": {"c1": {"c2": 3}}, "b0": {"b1": "b2"}, "a0": "a1"}

        # reorder=True: keys are expected in sorted order, not insertion order
        fdict = flatten_config_dict(d, reorder=True)
        gt = {"a0": "a1", "b0.b1": "b2", "c0.c1.c2": 3}
        self.assertEqual(fdict, gt)
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

        # reorder=False: keys are expected in the input dict's insertion order
        fdict = flatten_config_dict(d, reorder=False)
        gt = {"c0.c1.c2": 3, "b0.b1": "b2", "a0": "a1"}
        self.assertEqual(fdict, gt)
        self.assertEqual(list(fdict.keys()), list(gt.keys()))

    def test_config_dict_to_list_str(self):
        """Check convert config dict to a flat [key, value, key, value, ...] str list"""
        d = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        str_list = config_dict_to_list_str(d)
        # Non-string values (the int 3) are expected to be stringified.
        gt = ["a0", "a1", "b0.b1", "b2", "c0.c1.c2", "3"]
        self.assertEqual(str_list, gt)

    def test_get_from_flattened_config_dict(self):
        """Look up a nested value by its dotted key; a missing key yields None."""
        d = {"MODEL": {"MIN_DIM_SIZE": 360}}
        # exist
        self.assertEqual(get_from_flattened_config_dict(d, "MODEL.MIN_DIM_SIZE"), 360)
        # non-exist
        self.assertIsNone(get_from_flattened_config_dict(d, "MODEL.MODEL.INPUT_SIZE"))

    def test_get_cfg_diff_table(self):
        """Check compare two dicts: only differing flattened keys appear in the table"""
        d1 = {"a0": "a1", "b0": {"b1": "b2"}, "c0": {"c1": {"c2": 3}}}
        d2 = {"a0": "a1", "b0": {"b1": "b3"}, "c0": {"c1": {"c2": 4}}}
        table = get_cfg_diff_table(d1, d2)
        # assertIn / assertNotIn report the table contents on failure,
        # unlike assertTrue(x in table).
        self.assertNotIn("a0", table)  # a0 are the same
        self.assertIn("b0.b1", table)  # b0.b1 are different
        self.assertIn("c0.c1.c2", table)  # c0.c1.c2 are different


facebook-github-bot's avatar
facebook-github-bot committed
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class TestAutoScaleWorldSize(unittest.TestCase):
    """Tests for ``auto_scale_world_size`` on the default GeneralizedRCNNRunner config.

    Both tests read ``cfg`` after the call, i.e. they rely on the function
    updating the config in place.
    """

    def test_8gpu_to_1gpu(self):
        """
        when scaling a 8-gpu config to 1-gpu one, the batch size will be reduced by 8x
        """
        cfg = GeneralizedRCNNRunner().get_default_cfg()
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 8)
        batch_size_x8 = cfg.SOLVER.IMS_PER_BATCH
        # Precondition for the exact-division check below. Use a unittest
        # assertion: a bare ``assert`` is stripped when Python runs with -O.
        self.assertEqual(
            batch_size_x8 % 8, 0, "default batch size is not multiple of 8"
        )
        auto_scale_world_size(cfg, new_world_size=1)
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 1)
        self.assertEqual(cfg.SOLVER.IMS_PER_BATCH * 8, batch_size_x8)

    def test_not_scale_for_zero_world_size(self):
        """
        when reference world size is 0, no scaling should happen
        """
        cfg = GeneralizedRCNNRunner().get_default_cfg()
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 8)
        cfg.SOLVER.REFERENCE_WORLD_SIZE = 0
        batch_size_x8 = cfg.SOLVER.IMS_PER_BATCH
        auto_scale_world_size(cfg, new_world_size=1)
        # Both the reference world size and the batch size must be untouched.
        self.assertEqual(cfg.SOLVER.REFERENCE_WORLD_SIZE, 0)
        self.assertEqual(cfg.SOLVER.IMS_PER_BATCH, batch_size_x8)


if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it;
    # "__main__" is spelled out here but is also unittest.main's default module.
    unittest.main(module="__main__")