test_cgo_engine.py 13.8 KB
Newer Older
1
2
3
4
import os
import threading
import unittest
import time
QuanluZhang's avatar
QuanluZhang committed
5
import torch
6
import torch.nn as nn
7
from pytorch_lightning.utilities.seed import seed_everything
QuanluZhang's avatar
QuanluZhang committed
8
9

from pathlib import Path
10

11
import nni
12
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
13
import nni.runtime.platform.test
14
from nni.runtime.tuner_command_channel import legacy as protocol
15
import json
16

17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
try:
    from nni.common.device import GPUDevice
    from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
    from nni.retiarii import Model
    from nni.retiarii.graph import Node

    from nni.retiarii import Model, submit_models
    from nni.retiarii.integration import RetiariiAdvisor
    from nni.retiarii.execution import set_execution_engine
    from nni.retiarii.execution.logical_optimizer.opt_dedup_input import DedupInputOptimizer
    from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
    from nni.retiarii.utils import import_

    from nni.retiarii import serialize
    import nni.retiarii.evaluator.pytorch.lightning as pl
    from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
    import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer

35
36
    import nni.retiarii.integration_api

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
    module_import_failed = False
except ImportError:
    module_import_failed = True

import pytest
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.datasets import load_diabetes


class _model_cpu(nn.Module):
    """Two independent MNIST branches fed by the same input, all on CPU.

    Used by ``test_multi_model_trainer_cpu`` as the merged network handed to
    the multi-model Lightning module. Branch ``M_1_*`` and branch ``M_2_*``
    each run stem -> flatten -> fc(1024->256) -> fc(256->10) -> softmax, and
    ``forward`` returns the pair ``(M_1_softmax, M_2_softmax)``.
    """

    def __init__(self):
        super().__init__()
        # Layers are randomly initialized on construction, so the creation
        # order below is kept stable for the seeded tests.
        self.M_1_stem = M_1_stem()
        self.M_2_stem = M_2_stem()
        self.M_1_flatten = torch.nn.Flatten()
        self.M_2_flatten = torch.nn.Flatten()
        self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
        self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024)
        self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256)
        self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256)
        # The softmax input is the 2-D (batch, 10) output of fc2, for which
        # PyTorch's implicit dim resolves to 1; stating it explicitly keeps
        # the same result without the implicit-dim deprecation warning.
        self.M_1_softmax = torch.nn.Softmax(dim=1)
        self.M_2_softmax = torch.nn.Softmax(dim=1)

    def forward(self, *_inputs):
        # Both branches consume the same first positional input.
        M_1__inputs_to_M_2_stem = _inputs[0]
        M_1_stem = self.M_1_stem(_inputs[0])
        M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
        M_1_flatten = self.M_1_flatten(M_1_stem)
        M_2_flatten = self.M_2_flatten(M_2_stem)
        M_1_fc1 = self.M_1_fc1(M_1_flatten)
        M_2_fc1 = self.M_2_fc1(M_2_flatten)
        M_1_fc2 = self.M_1_fc2(M_1_fc1)
        M_2_fc2 = self.M_2_fc2(M_2_fc1)
        M_1_softmax = self.M_1_softmax(M_1_fc2)
        M_2_softmax = self.M_2_softmax(M_2_fc2)
        return M_1_softmax, M_2_softmax


class _model_gpu(nn.Module):
    """Two-device variant of the merged MNIST network.

    Branch ``M_1_*`` is placed on ``cuda:0`` and branch ``M_2_*`` on
    ``cuda:1``; the shared input is copied to both devices in ``forward``.
    Constructing it therefore requires at least two CUDA devices. Returns
    ``(M_1_softmax, M_2_softmax)``, each on its branch's device.
    """

    def __init__(self):
        super().__init__()
        # Creation (and thus random-initialization) order is kept stable;
        # each layer is moved to its branch's device right after creation.
        self.M_1_stem = M_1_stem().to('cuda:0')
        self.M_2_stem = M_2_stem().to('cuda:1')
        self.M_1_flatten = torch.nn.Flatten().to('cuda:0')
        self.M_2_flatten = torch.nn.Flatten().to('cuda:1')
        self.M_1_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:0')
        self.M_2_fc1 = torch.nn.Linear(out_features=256, in_features=1024).to('cuda:1')
        self.M_1_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:0')
        self.M_2_fc2 = torch.nn.Linear(out_features=10, in_features=256).to('cuda:1')
        # Softmax over the class axis of the 2-D (batch, 10) logits. This is
        # what the implicit dim resolves to; being explicit avoids the
        # implicit-dim deprecation warning.
        self.M_1_softmax = torch.nn.Softmax(dim=1).to('cuda:0')
        self.M_2_softmax = torch.nn.Softmax(dim=1).to('cuda:1')

    def forward(self, *_inputs):
        # Copy the shared input to each branch's device.
        M_1__inputs_to_M_1_stem = _inputs[0].to("cuda:0")
        M_1__inputs_to_M_2_stem = _inputs[0].to("cuda:1")
        M_1_stem = self.M_1_stem(M_1__inputs_to_M_1_stem)
        M_2_stem = self.M_2_stem(M_1__inputs_to_M_2_stem)
        M_1_flatten = self.M_1_flatten(M_1_stem)
        M_2_flatten = self.M_2_flatten(M_2_stem)
        M_1_fc1 = self.M_1_fc1(M_1_flatten)
        M_2_fc1 = self.M_2_fc1(M_2_flatten)
        M_1_fc2 = self.M_1_fc2(M_1_fc1)
        M_2_fc2 = self.M_2_fc2(M_2_fc1)
        M_1_softmax = self.M_1_softmax(M_1_fc2)
        M_2_softmax = self.M_2_softmax(M_2_fc2)
        return M_1_softmax, M_2_softmax


class M_1_stem(nn.Module):
    """Convolutional stem of the first MNIST branch.

    conv(1->32, 5x5) -> maxpool(2) -> conv(32->64, 5x5) -> maxpool(2).
    """

    def __init__(self):
        super().__init__()
        # Conv layers are randomly initialized on construction; keep conv1
        # created before conv2 so seeded runs produce the same weights.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, *_inputs):
        features = _inputs[0]
        for stage in (self.conv1, self.pool1, self.conv2, self.pool2):
            features = stage(features)
        return features


class M_2_stem(nn.Module):
    """Convolutional stem of the second MNIST branch (same layout as M_1_stem).

    conv(1->32, 5x5) -> maxpool(2) -> conv(32->64, 5x5) -> maxpool(2).
    """

    def __init__(self):
        super().__init__()
        # Construction order fixes both child-module order and the order in
        # which the conv weights are randomly initialized.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, *_inputs):
        first_stage = self.pool1(self.conv1(_inputs[0]))
        return self.pool2(self.conv2(first_stage))


def _reset():
    """Put the module-level NNI state this test touches back to a known baseline.

    Keeps these tests from affecting other tests in the SDK suite:
    trial params and intermediate sequence, the test platform's last metric,
    and the registered Retiarii advisor / execution engine are all reset,
    then ``seed_everything(42)`` is called for reproducibility.
    """
    # Trial-side globals.
    nni.trial._params = {'foo': 'bar', 'parameter_id': 0}
    nni.trial._intermediate_seq = 0

    # Platform / Retiarii singletons.
    nni.runtime.platform.test._last_metric = None
    nni.retiarii.execution.api._execution_engine = None
    nni.retiarii.integration_api._advisor = None

    seed_everything(42)
148
149
150
151
152
153
154


def _new_trainer():
    """Build a fresh CGO-enabled ``pl.Lightning`` evaluator on MNIST.

    The evaluator wraps a ``_MultiModelSupervisedLearningModule`` (cross-entropy
    loss, an ``'acc'`` metric) and a ``cgo_trainer.Trainer`` running a single
    epoch over a quarter of the training batches with the progress bar off.
    Train/validation splits come from ``data/mnist`` (``download=True``) and
    are served in batches of 100.
    """
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Train split first, then test split, as before.
    train_split, test_split = (
        serialize(MNIST, root='data/mnist', train=is_train, download=True, transform=mnist_transform)
        for is_train in (True, False)
    )

    module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits})
    trainer = cgo_trainer.Trainer(
        use_cgo=True,
        max_epochs=1,
        limit_train_batches=0.25,
        enable_progress_bar=False,
    )
    train_loader = pl.DataLoader(train_split, batch_size=100)
    val_loader = pl.DataLoader(test_split, batch_size=100)
    return pl.Lightning(module, trainer, train_dataloader=train_loader, val_dataloaders=val_loader)
164
165
166


def _load_mnist(n_models: int = 1):
    """Load the MNIST graph model from ``ut/nas/mnist_pytorch.json``.

    Every returned model gets its own evaluator from ``_new_trainer()``.

    Args:
        n_models: how many models to produce; models beyond the first are
            forks of the loaded one.

    Returns:
        The single loaded ``Model`` when ``n_models == 1``; otherwise a list
        starting with the loaded model followed by ``n_models - 1`` forks.
    """
    path = Path('ut/nas/mnist_pytorch.json')
    # Hold the file only while parsing; building the evaluator (dataset
    # setup) does not need it. Explicit encoding avoids locale dependence.
    with open(path, encoding='utf-8') as f:
        mnist_model = Model._load(nni.load(fp=f))
    mnist_model.evaluator = _new_trainer()

    if n_models == 1:
        return mnist_model

    models = [mnist_model]
    for _ in range(n_models - 1):
        forked_model = mnist_model.fork()
        forked_model.evaluator = _new_trainer()
        models.append(forked_model)
    return models
181
182


183
def _get_final_result():
    """Return the last metric recorded by the NNI test platform as a list of floats.

    The recorded ``'value'`` may be a list, a single scalar, or a string that
    encodes a list (detected by containing ``'['``). All three shapes are
    normalized to ``list[float]`` so single- and multi-model callers can
    treat the result uniformly.
    """
    result = nni.load(nni.runtime.platform.test._last_metric)['value']
    if isinstance(result, str) and '[' in result:
        # Decode an encoded list first so it gets the same float coercion as
        # a native list (previously this branch returned values unconverted).
        result = nni.load(result)
    if isinstance(result, list):
        return [float(v) for v in result]
    return [float(result)]


193
class CGOEngineTest(unittest.TestCase):
    """Tests for ``CGOExecutionEngine`` and its supporting pieces.

    Covers training two merged models with the CGO Lightning trainer, merging
    forked MNIST models into a ``LogicalPlan``, assembling input-deduplicated
    plans for a given set of GPUs, and submitting models end to end.
    """

    def setUp(self):
        if module_import_failed:
            self.skipTest('test skip due to failed import of nni.retiarii.evaluator.pytorch.lightning')

    # ------------------------------------------------------------------ helpers

    def _run_multi_model_trainer(self, model_cls):
        """Train ``model_cls`` (a two-branch network) for one epoch on MNIST.

        Asserts that two final metric values are reported and that each
        exceeds 0.8.
        """
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
        test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)

        multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits}, n_models=2)

        lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
                                                                   max_epochs=1,
                                                                   limit_train_batches=0.25),
                                 train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                 val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))

        lightning._execute(model_cls)

        result = _get_final_result()
        self.assertEqual(len(result), 2)
        for value in result:
            self.assertGreater(value, 0.8)

    def _build_logical_with_mnist(self, n_models: int):
        """Return ``(plan, models)``: a ``LogicalPlan`` holding ``n_models`` MNIST models."""
        lp = LogicalPlan()
        models = _load_mnist(n_models=n_models)
        for m in models:
            lp.add_model(m)
        return lp, models

    @staticmethod
    def _start_advisor():
        """Create a ``RetiariiAdvisor`` on the legacy command channel and start its workers."""
        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
        advisor._channel = protocol.LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()
        return advisor

    @staticmethod
    def _shutdown(advisor, engine=None):
        """Stop and join the advisor's workers, then join ``engine`` if one was created."""
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()
        if engine is not None:
            engine.join()

    @staticmethod
    def _remote_config(gpu_indices):
        """Return a one-machine ``RemoteConfig`` (host ``'test'``) exposing ``gpu_indices``."""
        remote = RemoteConfig(machine_list=[])
        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=gpu_indices))
        return remote

    def _assemble_deduped_mnist(self, gpu_indices):
        """Merge 3 MNIST models, dedup their inputs, and assemble them for ``gpu_indices``.

        Returns the physical models from ``CGOExecutionEngine._assemble``.
        Teardown runs in ``finally`` so a failure never skips stopping and
        joining the advisor/engine threads.
        """
        lp, _ = self._build_logical_with_mnist(3)
        DedupInputOptimizer().convert(lp)

        advisor = self._start_advisor()
        cgo = None
        try:
            cgo = CGOExecutionEngine(training_service=self._remote_config(gpu_indices), batch_waiting_time=0)
            return cgo._assemble(lp)
        finally:
            self._shutdown(advisor, cgo)

    @staticmethod
    def _run_trial_forwarding_metrics(advisor, tt):
        """Run ``CGOExecutionEngine.trial_execute_graph`` in a thread and relay its metrics.

        Every distinct record written to ``tt._last_metric`` is forwarded once
        to ``advisor.handle_report_metric_data`` with its ``'value'`` JSON-encoded.
        """
        trial_thread = threading.Thread(target=CGOExecutionEngine.trial_execute_graph)
        trial_thread.start()
        last_raw = None

        def forward_if_new():
            nonlocal last_raw
            raw = tt._last_metric
            # Dedup on the raw record. The forwarded dict has its 'value'
            # re-serialized, so comparing it to a freshly read metric may
            # never match (which made the previous dedup ineffective).
            if not raw or raw == last_raw:
                return
            metric = tt.get_last_metric()
            if 'value' in metric:
                metric['value'] = json.dumps(metric['value'])
            advisor.handle_report_metric_data(metric)
            last_raw = raw

        # Liveness is checked on every iteration; the previous `continue`
        # could bypass it and poll forever once the thread had exited.
        while trial_thread.is_alive():
            time.sleep(1)
            forward_if_new()
        trial_thread.join()
        # Pick up a metric reported between the last poll and thread exit.
        forward_if_new()

    # -------------------------------------------------------------------- tests

    def test_multi_model_trainer_cpu(self):
        _reset()
        self._run_multi_model_trainer(_model_cpu)

    def test_multi_model_trainer_gpu(self):
        # self.skipTest (not pytest.skip) so skipping also works under unittest.main().
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= 2):
            self.skipTest('test requires GPU and torch+cuda')
        _reset()
        self._run_multi_model_trainer(_model_gpu)

    def test_add_model(self):
        _reset()

        lp, models = self._build_logical_with_mnist(3)

        for node in lp.logical_graph.hidden_nodes:
            old_nodes = [m.root_graph.get_node_by_id(node.id) for m in models]
            # Every model must agree with the first one. An any(...) check is
            # vacuous here because old_nodes[0] always matches itself.
            reference = Node.__repr__(old_nodes[0])
            for other in old_nodes[1:]:
                self.assertEqual(Node.__repr__(other), reference)

    def test_dedup_input_four_devices(self):
        _reset()
        phy_models = self._assemble_deduped_mnist([0, 1, 2, 3])
        self.assertEqual(len(phy_models), 1)

    def test_dedup_input_two_devices(self):
        _reset()
        phy_models = self._assemble_deduped_mnist([0, 1])
        self.assertEqual(len(phy_models), 2)

    def test_submit_models(self):
        _reset()
        os.makedirs('generated', exist_ok=True)
        import nni.runtime.platform.test as tt
        # The legacy protocol reads back what it writes through this file.
        # NOTE(review): these handles are installed as module-level streams and
        # are not closed here, since later users of `protocol` may still use them.
        protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
        protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb'))

        models = _load_mnist(2)

        advisor = self._start_advisor()
        cgo_engine = None
        try:
            cgo_engine = CGOExecutionEngine(training_service=self._remote_config([0, 1, 2, 3]), batch_waiting_time=0)
            set_execution_engine(cgo_engine)
            submit_models(*models)
            # Presumably gives the engine time to dispatch the submitted models.
            time.sleep(3)

            if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
                _cmd, data = protocol.receive()
                params = nni.load(data)
                tt.init_params(params)
                self._run_trial_forwarding_metrics(advisor, tt)
        finally:
            self._shutdown(advisor, cgo_engine)
359
360
361


if __name__ == '__main__':
    # Running this file directly uses the stdlib unittest runner.
    unittest.main()