test_tuner.py 9.31 KB
Newer Older
Deshui Yu's avatar
Deshui Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
20
import glob
Deshui Yu's avatar
Deshui Yu committed
21
import json
22
23
24
25
import logging
import os
import shutil
import sys
Deshui Yu's avatar
Deshui Yu committed
26
27
from unittest import TestCase, main

28
29
30
31
32
33
34
35
36
37
38
from nni.batch_tuner.batch_tuner import BatchTuner
from nni.evolution_tuner.evolution_tuner import EvolutionTuner
from nni.gp_tuner.gp_tuner import GPTuner
from nni.gridsearch_tuner.gridsearch_tuner import GridSearchTuner
from nni.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
from nni.metis_tuner.metis_tuner import MetisTuner
try:
    from nni.smac_tuner.smac_tuner import SMACTuner
except ImportError:
    assert sys.platform == "win32"
from nni.tuner import Tuner
Deshui Yu's avatar
Deshui Yu committed
39

40
41
# Module-wide logger for the tuner tests; the root logger is configured at INFO
# so that messages logged via ``logger.info`` are emitted during test runs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name='test_tuner')
Deshui Yu's avatar
Deshui Yu committed
42
43
44


class TunerTestCase(TestCase):
    """
    Targeted at testing functions of built-in tuners, including
        - [ ] load_checkpoint
        - [ ] save_checkpoint
        - [X] update_search_space
        - [X] generate_multiple_parameters
        - [ ] import_data
        - [ ] trial_end
        - [ ] receive_trial_result

    Each ``test_<tuner>`` method feeds every entry of ``assets/search_space.json`` to one
    tuner: entries of a supported type must produce in-range parameters, while unsupported
    entries (or entries whose ``"fail"`` list names the tuner) must raise.
    """

    def search_space_test_one(self, tuner_factory, search_space):
        """Install ``search_space`` on a fresh tuner, request parameters for 50 ids and validate them.

        Args:
            tuner_factory: zero-argument callable returning a new ``Tuner`` instance.
            search_space: search-space dict to pass to ``update_search_space``.

        Returns:
            The list returned by ``generate_multiple_parameters``.

        Raises:
            ValueError: if the tuner returned no parameters.
            AssertionError: if a generated value violates its search-space entry.
        """
        tuner = tuner_factory()
        self.assertIsInstance(tuner, Tuner)
        tuner.update_search_space(search_space)

        parameters = tuner.generate_multiple_parameters(list(range(50)))
        logger.info(parameters)
        self.check_range(parameters, search_space)
        if not parameters:  # TODO: not strict
            raise ValueError("No parameters generated")
        return parameters

    def check_range(self, generated_params, search_space):
        """Assert that every generated value is consistent with its search-space entry.

        Args:
            generated_params: parameter sets returned by the tuner. For ``test_batch`` these
                are bare values rather than dicts.
            search_space: the search space the parameters were generated from.
        """
        EPS = 1E-6
        for param in generated_params:
            if self._testMethodName == "test_batch":
                # BatchTuner yields bare values; key each one by the first search-space key
                # so the checks below can treat it like any other parameter dict.
                param = {list(search_space.keys())[0]: param}
            for k, v in param.items():
                if k.startswith("_mutable_layer"):
                    # Flattened mutable-layer key: "_mutable_layer/<block>/<layer>/<choice>".
                    _, block, layer, choice = k.split("/")
                    cand = search_space[block]["_value"][layer].get(choice)
                    # cand could be None, e.g., optional_inputs_chosen_state
                    if choice == "layer_choice":
                        self.assertIn(v, cand)
                    if choice == "optional_input_size":
                        if isinstance(cand, int):
                            self.assertEqual(v, cand)
                        else:
                            self.assertGreaterEqual(v, cand[0])
                            self.assertLessEqual(v, cand[1])
                    if choice == "optional_inputs":
                        pass  # ignore for now
                    continue
                item = search_space[k]
                if item["_type"] == "choice":
                    self.assertIn(v, item["_value"])
                if item["_type"] == "randint":
                    self.assertIsInstance(v, int)
                if item["_type"] == "uniform":
                    self.assertIsInstance(v, float)
                if item["_type"] in ("randint", "uniform", "quniform", "loguniform", "qloguniform"):
                    self.assertGreaterEqual(v, item["_value"][0])
                    self.assertLessEqual(v, item["_value"][1])
                if item["_type"].startswith("q"):
                    multiple = v / item["_value"][2]
                    logger.debug("%s %s %s %s", k, v, multiple, item)
                    # Only strictly interior values are required to be multiples of q
                    # (presumably because values at a bound may be clipped).
                    # NOTE(review): for qnormal/qlognormal, _value[0] and _value[1] are likely
                    # mu/sigma rather than bounds, so this guard is loose for them — confirm.
                    if item["_value"][0] + EPS < v < item["_value"][1] - EPS:
                        self.assertAlmostEqual(int(round(multiple)), multiple)
                if item["_type"] in ("qlognormal", "lognormal"):
                    self.assertGreaterEqual(v, 0)
                if item["_type"] == "mutable_layer":
                    for layer_name, layer_space in item["_value"].items():
                        # Candidate layers live under each layer's own entry, matching the
                        # search_space[block]["_value"][layer] lookup used for flattened keys.
                        self.assertIn(v[layer_name]["chosen_layer"], layer_space["layer_choice"])

    def search_space_test_all(self, tuner_factory, supported_types=None, ignore_types=None):
        """Run every entry of ``assets/search_space.json`` against one tuner.

        An entry is expected to succeed when its underscore-separated name contains a supported
        type and does not contain the word ``fail``. The entry must also not list the current
        tuner (the test-method name without its ``test_`` prefix) in its ``"fail"`` field.
        Every other entry must raise. Finally, except for batch and grid search, all
        succeeding entries are merged into one search space and tested together.

        Args:
            tuner_factory: zero-argument callable returning a new tuner instance.
            supported_types: type keywords the tuner supports; defaults to all ten standard
                NNI types listed below.
            ignore_types: keywords whose entries are skipped entirely (neither success nor
                failure is asserted).
        """
        # NOTE(yuge): ignore types
        # Supported types are listed in the table. They are meant to be supported and should be correct.
        # Other than those, all the rest are "unsupported", which are expected to produce ridiculous results
        # or throw some exceptions. However, there are certain types I can't check. For example, generate
        # "normal" using GP Tuner returns successfully and results are fine if we check the range (-inf to +inf),
        # but they make no sense: it's not a normal distribution. So they are ignored in tests for now.
        asset_path = os.path.join(os.path.dirname(__file__), "assets/search_space.json")
        with open(asset_path, "r", encoding="utf-8") as fp:
            search_space_all = json.load(fp)
        if supported_types is None:
            supported_types = ["choice", "randint", "uniform", "quniform", "loguniform", "qloguniform",
                               "normal", "qnormal", "lognormal", "qlognormal"]
        full_supported_search_space = {}
        for single in search_space_all:
            single_keyword = single.split("_")
            space = search_space_all[single]
            expected_fail = not any(t in single_keyword for t in supported_types) or "fail" in single_keyword
            if ignore_types is not None and any(t in ignore_types for t in single_keyword):
                continue
            # "fail" lists the tuners expected to reject this entry; pop it so the key is not
            # handed to the tuner as part of the search space.
            if "fail" in space:
                if self._testMethodName.split("_", 1)[1] in space.pop("fail"):
                    expected_fail = True
            single_search_space = {single: space}
            if not expected_fail:
                # supports this key
                self.search_space_test_one(tuner_factory, single_search_space)
                full_supported_search_space.update(single_search_space)
            else:
                # unsupported key
                with self.assertRaises(Exception, msg="Testing {}".format(single)) as cm:
                    self.search_space_test_one(tuner_factory, single_search_space)
                logger.info("%s %s %s", tuner_factory, single, cm.exception)
        if not any(t in self._testMethodName for t in ["batch", "grid_search"]):
            # grid search fails for too many combinations
            logger.info("Full supported search space: %s", full_supported_search_space)
            self.search_space_test_one(tuner_factory, full_supported_search_space)

    def test_grid_search(self):
        self.search_space_test_all(lambda: GridSearchTuner(),
                                   supported_types=["choice", "randint", "quniform"])

    def test_tpe(self):
        self.search_space_test_all(lambda: HyperoptTuner("tpe"))

    def test_random_search(self):
        self.search_space_test_all(lambda: HyperoptTuner("random_search"))

    def test_anneal(self):
        self.search_space_test_all(lambda: HyperoptTuner("anneal"))

    def test_smac(self):
        if sys.platform == "win32":
            # SMAC doesn't work on Windows (and SMACTuner is not importable there); report a
            # skip instead of a silent pass.
            self.skipTest("SMAC tuner is not supported on Windows")
        self.search_space_test_all(lambda: SMACTuner(),
                                   supported_types=["choice", "randint", "uniform", "quniform", "loguniform"])

    def test_batch(self):
        self.search_space_test_all(lambda: BatchTuner(),
                                   supported_types=["choice"])

    def test_evolution(self):
        # Needs enough population size, otherwise it will throw a runtime error
        self.search_space_test_all(lambda: EvolutionTuner(population_size=100))

    def test_gp(self):
        self.search_space_test_all(lambda: GPTuner(),
                                   supported_types=["choice", "randint", "uniform", "quniform", "loguniform",
                                                    "qloguniform"],
                                   ignore_types=["normal", "lognormal", "qnormal", "qlognormal"])

    def test_metis(self):
        self.search_space_test_all(lambda: MetisTuner(),
                                   supported_types=["choice", "randint", "uniform", "quniform"])

    def test_networkmorphism(self):
        # Placeholder: report as skipped rather than as a (vacuous) pass.
        self.skipTest("network morphism tuner test not implemented yet")

    def test_ppo(self):
        # Placeholder: report as skipped rather than as a (vacuous) pass.
        self.skipTest("PPO tuner test not implemented yet")

    def tearDown(self):
        """Remove artifacts tuners may leave in the working directory (``smac3*`` paths,
        ``param_config_space.pcs``, ``scenario.txt``, ``model_path``), whether files or directories."""
        file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"]
        for file in file_list:
            if os.path.exists(file):
                if os.path.isdir(file):
                    shutil.rmtree(file)
                else:
                    os.remove(file)
Deshui Yu's avatar
Deshui Yu committed
198
199
200
201


if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's test runner.
    main()