test_config.py 1.68 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import unittest

from d2go.config import CfgNode


class TestConfigNode(unittest.TestCase):
    @staticmethod
    def _get_default_config():
        cfg = CfgNode()
        cfg.INPUT = CfgNode()
        cfg.INPUT.CROP = CfgNode()
        cfg.INPUT.CROP.ENABLED = False
        cfg.INPUT.CROP.SIZE = (0.9, 0.9)
        cfg.INPUT.CROP.TYPE = "relative_range"
        cfg.MODEL = CfgNode()
        cfg.MODEL.MIN_DIM_SIZE = 360
        cfg.INFERENCE_SDK = CfgNode()
        cfg.INFERENCE_SDK.MODEL = CfgNode()
        cfg.INFERENCE_SDK.MODEL.SCORE_THRESHOLD = 0.8
        cfg.INFERENCE_SDK.IOU_TRACKER = CfgNode()
        cfg.INFERENCE_SDK.IOU_TRACKER.IOU_THRESHOLD = 0.15
        cfg.INFERENCE_SDK.ENABLE_ID_TRACKING = True
        return cfg

    def test_get_field_or_none(self):
        cfg = self._get_default_config()
        self.assertEqual(cfg.get_field_or_none("MODEL.MIN_DIM_SIZE"), 360)
        self.assertEqual(
            cfg.get_field_or_none("INFERENCE_SDK.ENABLE_ID_TRACKING"), True
        )
        self.assertEqual(cfg.get_field_or_none("MODEL.INPUT_SIZE"), None)
        self.assertEqual(cfg.get_field_or_none("MODEL.INPUT_SIZE.HEIGHT"), None)

    def test_as_flattened_dict(self):
        cfg = self._get_default_config()
        cfg_dict = cfg.as_flattened_dict()
        target_cfg_dict = {
            "INPUT.CROP.ENABLED": False,
            "INPUT.CROP.SIZE": (0.9, 0.9),
            "INPUT.CROP.TYPE": "relative_range",
            "MODEL.MIN_DIM_SIZE": 360,
            "INFERENCE_SDK.MODEL.SCORE_THRESHOLD": 0.8,
            "INFERENCE_SDK.IOU_TRACKER.IOU_THRESHOLD": 0.15,
            "INFERENCE_SDK.ENABLE_ID_TRACKING": True,
        }
        self.assertEqual(target_cfg_dict, cfg_dict)