test_config_base.py 3.57 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from nni.experiment.config.base import ConfigBase

# config classes

@dataclass(init=False)
class NestedChild(ConfigBase):
    """Innermost test config; after canonicalization ``msg`` must end with ``[2]``."""

    msg: str
    int_field: int = 1

    def _canonicalize(self, parents):
        # A message with no '/' is treated as relative and gets qualified with
        # parents[0].msg (presumably the immediate parent, judging by the
        # expected canonical value 'a/b[1]/c[2]').
        if '/' not in self.msg:
            self.msg = '/'.join([parents[0].msg, self.msg])
        super()._canonicalize(parents)

    def _validate_canonical(self):
        super()._validate_canonical()
        if self.msg.endswith('[2]'):
            return
        raise ValueError('not end with [2]')

@dataclass(init=False)
class Child(ConfigBase):
    """Middle-level test config; after canonicalization ``msg`` must end with ``[1]``."""

    msg: str
    children: List[NestedChild]  # nested configs, canonicalized after this one

    def _canonicalize(self, parents):
        # Qualify a relative message with parents[0].msg before delegating, so
        # that nested children see the already-qualified message.
        already_qualified = '/' in self.msg
        if not already_qualified:
            self.msg = '/'.join([parents[0].msg, self.msg])
        super()._canonicalize(parents)

    def _validate_canonical(self):
        super()._validate_canonical()
        if self.msg.endswith('[1]'):
            return
        raise ValueError('not end with "[1]"')

@dataclass(init=False)
class TestConfig(ConfigBase):
    """Top-level test config covering required, optional, union-typed and nested fields."""

    # The class name starts with "Test", so pytest would otherwise try to collect
    # it as a test class (and warn, since it has an __init__). This is a plain
    # class attribute without an annotation, so it is not a dataclass field.
    __test__ = False

    msg: str
    # Typed Optional but declared without a default value.
    required_field: Optional[int]
    optional_field: Optional[int] = None
    # Accepts a scalar or a list; _canonicalize wraps a scalar into a list.
    # (Non-default after default is allowed here because init=False.)
    multi_type_field: Union[int, List[int]]
    child: Optional[Child] = None

    def _canonicalize(self, parents):
        if isinstance(self.multi_type_field, int):
            self.multi_type_field = [self.multi_type_field]
        super()._canonicalize(parents)

# sample inputs

good = dict(
    msg='a',
    required_field=10,
    multi_type_field=20,
    child=dict(
        msg='b[1]',
        children=[
            dict(msg='c[2]', int_field=30),
            dict(msg='d[2]'),
        ],
    ),
)

# Each invalid sample is an independent deep copy of ``good`` with exactly one
# defect injected, so mutating one never leaks into ``good`` or the others.
missing = deepcopy(good)
del missing['required_field']  # field declared without a default is absent

wrong_type = deepcopy(good)
wrong_type['optional_field'] = 0.5  # float where Optional[int] is declared

nested_wrong_type = deepcopy(good)
nested_wrong_type['child']['children'][1]['int_field'] = 'str'  # str where int is declared, two levels down

bad_value = deepcopy(good)
bad_value['child']['msg'] = 'b'  # lacks the "[1]" suffix Child validates

extra_field = deepcopy(good)
extra_field['hello'] = 'world'  # key not declared on TestConfig

bads = dict(
    missing=missing,
    wrong_type=wrong_type,
    nested_wrong_type=nested_wrong_type,
    bad_value=bad_value,
    extra_field=extra_field,
)

# ground truth

def _make(config_class, **field_values):
    """Instantiate ``config_class()`` with no arguments, then set each field by attribute assignment."""
    instance = config_class()
    for field_name, field_value in field_values.items():
        setattr(instance, field_name, field_value)
    return instance


# What TestConfig(**good) is expected to produce right after construction.
_nested_child_1 = _make(NestedChild, msg='c[2]', int_field=30)
_nested_child_2 = _make(NestedChild, msg='d[2]', int_field=1)
_child = _make(Child, msg='b[1]', children=[_nested_child_1, _nested_child_2])
good_config = _make(
    TestConfig,
    msg='a',
    required_field=10,
    optional_field=None,
    multi_type_field=20,
    child=_child,
)

# Expected canonical form: relative messages qualified with their parents'
# messages, and the scalar multi_type_field wrapped in a list.
_nested_child_1 = _make(NestedChild, msg='a/b[1]/c[2]', int_field=30)
_nested_child_2 = _make(NestedChild, msg='a/b[1]/d[2]', int_field=1)
_child = _make(Child, msg='a/b[1]', children=[_nested_child_1, _nested_child_2])
good_canon_config = _make(
    TestConfig,
    msg='a',
    required_field=10,
    optional_field=None,
    multi_type_field=[20],
    child=_child,
)

# test function

def test_good():
    """``good`` must load into ``good_config``; after validate(), its JSON must match ``good_canon_config``."""
    parsed = TestConfig(**good)
    assert parsed == good_config
    parsed.validate()
    actual_json = parsed.json()
    assert actual_json == good_canon_config.json()

def test_bad():
    """Every sample in ``bads`` must be rejected, either at construction or by validate()."""
    for tag, bad in bads.items():
        exc = None
        try:
            config = TestConfig(**bad)
            config.validate()
        except Exception as e:  # any error counts as a rejection; the exact type is not pinned here
            exc = e
        # Name the offending sample so a failure is attributable.
        assert exc is not None, f'{tag}: invalid config was accepted without error'
liuzhe-lz's avatar
liuzhe-lz committed
149
150
151
152

if __name__ == '__main__':
    # Allow running the tests directly as a script, without pytest.
    for _run in (test_good, test_bad):
        _run()