test_hpo_utils.py 2.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from nni.common.hpo_utils import validate_search_space

good = {
    'choice': { '_type': 'choice', '_value': ['a', 'b'] },
    'randint': { '_type': 'randint', '_value': [1, 10] },
    'uniform': { '_type': 'uniform', '_value': [0, 1.0] },
    'quniform': { '_type': 'quniform', '_value': [1, 10, 0.1] },
    'loguniform': { '_type': 'loguniform', '_value': [0.001, 0.1] },
    'qloguniform': { '_type': 'qloguniform', '_value': [0.001, 0.1, 0.001] },
    'normal': { '_type': 'normal', '_value': [0, 0.1] },
    'qnormal': { '_type': 'qnormal', '_value': [0.5, 0.1, 0.1] },
    'lognormal': { '_type': 'lognormal', '_value': [0.0, 1] },
    'qlognormal': { '_type': 'qlognormal', '_value': [-1, 1, 0.1] },
}
good_partial = {
    'choice': good['choice'],
    'randint': good['randint'],
}

bad_type = 'x'
bad_spec_type = { 'x': [1, 2, 3] }
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
bad_float_args = { 'x': { '_type': 'uniform', '_value': ['0.1', '0.2'] } }
bad_low_high = { 'x': { '_type': 'quniform', '_value': [2, 1, 0.1] } }
bad_log = { 'x': { '_type': 'loguniform', '_value': [0, 1] } }
bad_sigma = { 'x': { '_type': 'normal', '_value': [0, 0] } }

def test_hpo_utils():
    assert validate_search_space(good, raise_exception=False)
    assert not validate_search_space(bad_type, raise_exception=False)
    assert not validate_search_space(bad_spec_type, raise_exception=False)
    assert not validate_search_space(bad_fields, raise_exception=False)
    assert not validate_search_space(bad_type_name, raise_exception=False)
    assert not validate_search_space(bad_value, raise_exception=False)
    assert not validate_search_space(bad_2_args, raise_exception=False)
    assert not validate_search_space(bad_3_args, raise_exception=False)
    assert not validate_search_space(bad_int_args, raise_exception=False)
    assert not validate_search_space(bad_float_args, raise_exception=False)
    assert not validate_search_space(bad_low_high, raise_exception=False)
    assert not validate_search_space(bad_log, raise_exception=False)
    assert not validate_search_space(bad_sigma, raise_exception=False)

    assert validate_search_space(good_partial, ['choice', 'randint'], False)
    assert not validate_search_space(good, ['choice', 'randint'], False)

if __name__ == '__main__':
    test_hpo_utils()