# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
import tempfile
from copy import deepcopy
from unittest import TestCase, main

import torch

from nni.algorithms.hpo.networkmorphism_tuner.graph import graph_to_json, json_to_graph
from nni.algorithms.hpo.networkmorphism_tuner.graph_transformer import (
    to_deeper_graph,
    to_skip_connection_graph,
    to_wider_graph,
)
from nni.algorithms.hpo.networkmorphism_tuner.layers import layer_description_extractor
from nni.algorithms.hpo.networkmorphism_tuner.networkmorphism_tuner import NetworkMorphismTuner
from nni.algorithms.hpo.networkmorphism_tuner.nn import CnnGenerator


class NetworkMorphismTestCase(TestCase):
    """  unittest for NetworkMorphismTuner
    """

    def test_graph_json_transform(self):
        """Check that a morphed graph survives a JSON round trip.

        A generated CNN graph is widened, deepened and given a skip
        connection, serialized with ``graph_to_json`` and rebuilt with
        ``json_to_graph``. The rebuilt graph is then compared with the
        original field by field.
        """

        graph_init = CnnGenerator(10, (32, 32, 3)).generate()
        # Apply the three morphisms in a fixed order, each on its own copy.
        graph_init = to_wider_graph(deepcopy(graph_init))
        graph_init = to_deeper_graph(deepcopy(graph_init))
        graph_init = to_skip_connection_graph(deepcopy(graph_init))

        # graph_to_json takes an output path; presumably it writes the JSON
        # there (TODO confirm). A throwaway directory keeps the test from
        # leaving "temp.json" behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            json_out = graph_to_json(
                graph_init, os.path.join(tmp_dir, "temp.json"))

        graph_recover = json_to_graph(json_out)

        # Fields that are compared directly by value.
        self.assertEqual(graph_init.input_shape, graph_recover.input_shape)
        self.assertEqual(graph_init.weighted, graph_recover.weighted)
        self.assertEqual(
            graph_init.layer_id_to_input_node_ids,
            graph_recover.layer_id_to_input_node_ids,
        )
        self.assertEqual(graph_init.adj_list, graph_recover.adj_list)
        self.assertEqual(
            graph_init.reverse_adj_list,
            graph_recover.reverse_adj_list,
        )
        # Only the number of recorded operations is compared, not the entries.
        self.assertEqual(
            len(graph_init.operation_history),
            len(graph_recover.operation_history),
        )
        self.assertEqual(graph_init.n_dim, graph_recover.n_dim)
        self.assertEqual(graph_init.conv, graph_recover.conv)
        self.assertEqual(graph_init.batch_norm, graph_recover.batch_norm)
        self.assertEqual(graph_init.vis, graph_recover.vis)

        # Nodes are compared through their shapes, in node_list order.
        node_list_init = [node.shape for node in graph_init.node_list]
        node_list_recover = [node.shape for node in graph_recover.node_list]
        self.assertEqual(node_list_init, node_list_recover)
        self.assertEqual(len(graph_init.node_to_id),
                         len(graph_recover.node_to_id))

        # Layers are compared through their extracted descriptions.
        layer_list_init = [
            layer_description_extractor(item, graph_init.node_to_id)
            for item in graph_init.layer_list
        ]
        layer_list_recover = [
            layer_description_extractor(item, graph_recover.node_to_id)
            for item in graph_recover.layer_list
        ]
        self.assertEqual(layer_list_init, layer_list_recover)

        # Ids assigned to nodes and layers must line up position by position.
        node_to_id_init = [
            graph_init.node_to_id[node] for node in graph_init.node_list
        ]
        node_to_id_recover = [
            graph_recover.node_to_id[node] for node in graph_recover.node_list
        ]
        self.assertEqual(node_to_id_init, node_to_id_recover)

        layer_to_id_init = [
            graph_init.layer_to_id[layer] for layer in graph_init.layer_list
        ]
        layer_to_id_recover = [
            graph_recover.layer_to_id[layer]
            for layer in graph_recover.layer_list
        ]
        self.assertEqual(layer_to_id_init, layer_to_id_recover)

    def test_to_wider_graph(self):
        """Check that a widened graph still yields a model with 10 outputs.

        The generated graph goes through a JSON round trip, is widened with
        ``to_wider_graph`` and turned into a torch model, which is then run
        on a single all-ones 3x32x32 input.
        """

        graph_init = CnnGenerator(10, (32, 32, 3)).generate()
        # graph_to_json takes an output path; presumably it writes the JSON
        # there (TODO confirm). A throwaway directory keeps the test from
        # leaving "temp.json" behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            json_out = graph_to_json(
                graph_init, os.path.join(tmp_dir, "temp.json"))
        graph_recover = json_to_graph(json_out)

        wider_graph = to_wider_graph(deepcopy(graph_recover))
        model = wider_graph.produce_torch_model()
        out = model(torch.ones(1, 3, 32, 32))
        self.assertEqual(out.shape, torch.Size([1, 10]))

    def test_to_deeper_graph(self):
        """Check that a deepened graph still yields a model with 10 outputs.

        The generated graph goes through a JSON round trip, is deepened with
        ``to_deeper_graph`` and turned into a torch model, which is then run
        on a single all-ones 3x32x32 input.
        """

        graph_init = CnnGenerator(10, (32, 32, 3)).generate()
        # graph_to_json takes an output path; presumably it writes the JSON
        # there (TODO confirm). A throwaway directory keeps the test from
        # leaving "temp.json" behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            json_out = graph_to_json(
                graph_init, os.path.join(tmp_dir, "temp.json"))
        graph_recover = json_to_graph(json_out)

        deeper_graph = to_deeper_graph(deepcopy(graph_recover))
        model = deeper_graph.produce_torch_model()
        out = model(torch.ones(1, 3, 32, 32))
        self.assertEqual(out.shape, torch.Size([1, 10]))

    def test_to_skip_connection_graph(self):
        """Check that adding a skip connection keeps the 10-way output.

        The generated graph goes through a JSON round trip, is transformed
        with ``to_skip_connection_graph`` and turned into a torch model,
        which is then run on a single all-ones 3x32x32 input.
        """

        graph_init = CnnGenerator(10, (32, 32, 3)).generate()
        # graph_to_json takes an output path; presumably it writes the JSON
        # there (TODO confirm). A throwaway directory keeps the test from
        # leaving "temp.json" behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            json_out = graph_to_json(
                graph_init, os.path.join(tmp_dir, "temp.json"))
        graph_recover = json_to_graph(json_out)

        skip_connection_graph = to_skip_connection_graph(
            deepcopy(graph_recover))
        model = skip_connection_graph.produce_torch_model()
        out = model(torch.ones(1, 3, 32, 32))
        self.assertEqual(out.shape, torch.Size([1, 10]))

    def test_generate_parameters(self):
        """Verify the first parameter set produced by a fresh tuner.

        The generated JSON must describe a 32x32x3 input, and the tail of the
        first ``total_data`` record must be ``(-1, 0)``.
        """

        tuner = NetworkMorphismTuner()
        generated = json.loads(tuner.generate_parameters(0))
        self.assertEqual(generated["input_shape"], [32, 32, 3])

        first_record = tuner.total_data[0]
        self.assertEqual(first_record[1:], (-1, 0))

    def test_receive_trial_result(self):
        """Verify the tuner's bookkeeping after one trial result is reported.

        Reports metric 0.7 for parameter id 0, then checks the stored record,
        the search tree entry for the model and the last history item.
        """

        tuner = NetworkMorphismTuner()
        generated = tuner.generate_parameters(0)
        tuner.receive_trial_result(0, {}, 0.7)

        # Each total_data record unpacks into (json, father id, model id).
        stored_json, parent_id, model_id = tuner.total_data[0]
        self.assertEqual(parent_id, -1)
        self.assertEqual(generated, stored_json)

        self.assertEqual(tuner.bo.search_tree.adj_list[model_id], [])
        expected_entry = {"model_id": 0, "metric_value": 0.7}
        self.assertEqual(tuner.history[-1], expected_entry)

    def test_update_search_space(self):
        """Verify that update_search_space replaces the stored search space.

        A fresh tuner starts with an empty dict and afterwards holds exactly
        the value that was passed in.
        """

        tuner = NetworkMorphismTuner()
        self.assertEqual(tuner.search_space, {})

        new_space = "Test"
        tuner.update_search_space(new_space)
        self.assertEqual(tuner.search_space, "Test")

    def test_init_search(self):
        """Verify the tuner state right after init_search.

        History starts empty; after initialization there is one model, one
        queued training item and one descriptor.
        """

        tuner = NetworkMorphismTuner()
        self.assertEqual(tuner.history, [])

        tuner.init_search()
        self.assertEqual(tuner.model_count, 1)
        # Both collections must hold exactly one entry.
        for collection_name in ("training_queue", "descriptors"):
            self.assertEqual(len(getattr(tuner, collection_name)), 1)

    def test_add_model(self):
        """Verify that add_model appends the reported result to history."""

        tuner = NetworkMorphismTuner()
        tuner.add_model(0.8, 0)

        latest_entry = tuner.history[-1]
        self.assertEqual(latest_entry, {"model_id": 0, "metric_value": 0.8})

    def test_get_best_model_id(self):
        """Verify that get_best_model_id picks the model with metric 0.9.

        Two models are registered, ids 0 and 1 with metrics 0.8 and 0.9;
        the reported best id must be 1.
        """

        tuner = NetworkMorphismTuner()
        # Model ids follow the position of each metric: 0 -> 0.8, 1 -> 0.9.
        for model_id, metric_value in enumerate((0.8, 0.9)):
            tuner.add_model(metric_value, model_id)

        best_id = tuner.get_best_model_id()
        self.assertEqual(best_id, 1)


if __name__ == "__main__":
    # Run the tests through unittest's main() when executed as a script.
    main()