# test_highlevel_apis.py (17.6 KB)
1
2
import random
import unittest
3
from collections import Counter
4
5
6
7

import nni.retiarii.nn.pytorch as nn
import torch
import torch.nn.functional as F
8
from nni.retiarii import Sampler, basic_unit
9
10
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
11
12
13
14
from nni.retiarii.execution.python import _unpack_if_only_one
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack
15
16


17
class EnumerateSampler(Sampler):
    """Deterministic sampler that hands out candidates in round-robin order."""

    def __init__(self):
        # Number of choices served so far; its remainder selects the next candidate.
        self.index = 0

    def choice(self, candidates, *args, **kwargs):
        """Return the candidate at the current position and advance by one.

        The position wraps around ``len(candidates)``; extra positional and
        keyword arguments are accepted and ignored.
        """
        position = self.index % len(candidates)
        picked = candidates[position]
        self.index = self.index + 1
        return picked


class RandomSampler(Sampler):
    """Sampler that picks uniformly at random and counts how often it was asked."""

    def __init__(self):
        # Incremented on every call to ``choice`` so tests can count sampler calls.
        self.counter = 0

    def choice(self, candidates, *args, **kwargs):
        """Record the call, then return a random element of ``candidates``."""
        self.counter = self.counter + 1
        picked = random.choice(candidates)
        return picked


36
@basic_unit
class MutableConv(nn.Module):
    """Module holding two 1x1 convolutions, selected by an integer at call time."""

    def __init__(self):
        super().__init__()
        # Both take 3 input channels; they differ in output channels (3 vs 5).
        self.conv1 = nn.Conv2d(3, 3, kernel_size=1)
        self.conv2 = nn.Conv2d(3, 5, kernel_size=1)

    def forward(self, x: torch.Tensor, index: int):
        """Apply ``conv1`` when ``index`` is 0, otherwise ``conv2``."""
        if index == 0:
            return self.conv1(x)
        return self.conv2(x)


50
class GraphIR(unittest.TestCase):
    """Exercise the inline mutation APIs through the graph-IR code path.

    Every test defines a small ``Net`` that uses ``LayerChoice``, ``InputChoice``
    or ``ValueChoice``, extracts its mutators, applies them with a sampler and
    checks the output produced by the resulting PyTorch model.

    ``_get_converted_pytorch_model``, ``_get_model_with_mutators`` and
    ``get_serializer`` are the hooks a subclass can override to run the same
    tests through a different conversion path.
    """

    def _convert_to_ir(self, model):
        """Script ``model`` with TorchScript and convert the result to graph IR."""
        script_module = torch.jit.script(model)
        return convert_to_graph(script_module, model)

    def _get_converted_pytorch_model(self, model_ir):
        """Generate PyTorch code for ``model_ir`` and return an instance of it."""
        model_code = model_to_pytorch_script(model_ir)
        exec_vars = {}
        # The generated script is expected to define ``_model``; instantiate it
        # inside a scratch namespace and pull the instance back out.
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        return exec_vars['converted_model']

    def _get_model_with_mutators(self, pytorch_model):
        """Return ``(model, mutators)`` for ``pytorch_model`` after IR conversion."""
        model = self._convert_to_ir(pytorch_model)
        mutators = process_inline_mutation(model)
        return model, mutators

    def get_serializer(self):
        """Return the class decorator applied to each test ``Net`` (a no-op here)."""
        def dummy(cls):
            return cls

        return dummy

    def test_layer_choice(self):
        """A ``LayerChoice`` of two convs yields one mutator that selects each in turn."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.module = nn.LayerChoice([
                    nn.Conv2d(3, 3, kernel_size=1),
                    nn.Conv2d(3, 5, kernel_size=1)
                ])

            def forward(self, x):
                return self.module(x)

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        # EnumerateSampler picks candidate 0 first, then candidate 1.
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 3, 3, 3]))
        self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 5, 3, 3]))

    def test_input_choice(self):
        """An ``InputChoice(2)`` yields one mutator that forwards one of two inputs."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 3, kernel_size=1)
                self.conv2 = nn.Conv2d(3, 5, kernel_size=1)
                self.input = nn.InputChoice(2)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                return self.input([x1, x2])

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 3, 3, 3]))
        self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 5, 3, 3]))

    def test_chosen_inputs(self):
        """``InputChoice`` with ``n_chosen=2`` honours every supported reduction."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self, reduction):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 3, kernel_size=1)
                self.conv2 = nn.Conv2d(3, 3, kernel_size=1)
                self.input = nn.InputChoice(2, n_chosen=2, reduction=reduction)

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                return self.input([x1, x2])

        for reduction in ['none', 'sum', 'mean', 'concat']:
            # subTest reports which reduction failed and keeps checking the others.
            with self.subTest(reduction=reduction):
                model, mutators = self._get_model_with_mutators(Net(reduction))
                self.assertEqual(len(mutators), 1)
                mutator = mutators[0].bind_sampler(EnumerateSampler())
                model = mutator.apply(model)
                result = self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3))
                if reduction == 'none':
                    # No reduction: both chosen tensors come back individually.
                    self.assertEqual(len(result), 2)
                    self.assertEqual(result[0].size(), torch.Size([1, 3, 3, 3]))
                    self.assertEqual(result[1].size(), torch.Size([1, 3, 3, 3]))
                elif reduction == 'concat':
                    # Two 3-channel tensors concatenated give 6 channels.
                    self.assertEqual(result.size(), torch.Size([1, 6, 3, 3]))
                else:
                    self.assertEqual(result.size(), torch.Size([1, 3, 3, 3]))

    def test_value_choice(self):
        """A ``ValueChoice`` called in ``forward`` feeds its value to a submodule."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.index = nn.ValueChoice([0, 1])
                self.conv = MutableConv()

            def forward(self, x):
                return self.conv(x, self.index())

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 3, 3, 3]))
        self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 5, 3, 3]))

    def test_value_choice_as_parameter(self):
        """A ``ValueChoice`` used as ``kernel_size`` changes the output's spatial size."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))

            def forward(self, x):
                return self.conv(x)

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        # A 5x5 input shrinks to 3x3 with kernel 3 and to 1x1 with kernel 5.
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                         torch.Size([1, 5, 3, 3]))
        self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
                         torch.Size([1, 5, 1, 1]))

    def test_value_choice_as_two_parameters(self):
        """Two independent ``ValueChoice`` parameters produce two mutators."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, nn.ValueChoice([6, 8]), kernel_size=nn.ValueChoice([3, 5]))

            def forward(self, x):
                return self.conv(x)

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 2)
        # Each mutator gets its own enumerating sampler, so both advance in lockstep.
        mutators[0].bind_sampler(EnumerateSampler())
        mutators[1].bind_sampler(EnumerateSampler())
        inputs = torch.randn(1, 3, 5, 5)
        model1 = mutators[1].apply(mutators[0].apply(model))
        self.assertEqual(self._get_converted_pytorch_model(model1)(inputs).size(),
                         torch.Size([1, 6, 3, 3]))
        model2 = mutators[1].apply(mutators[0].apply(model))
        self.assertEqual(self._get_converted_pytorch_model(model2)(inputs).size(),
                         torch.Size([1, 8, 1, 1]))

    def test_value_choice_as_parameter_shared(self):
        """``ValueChoice`` objects sharing a label collapse into a single mutator."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)
                self.conv2 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)

            def forward(self, x):
                return self.conv1(x) + self.conv2(x)

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                         torch.Size([1, 6, 5, 5]))
        self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
                         torch.Size([1, 8, 5, 5]))

    def test_value_choice_in_functional(self):
        """A ``ValueChoice`` can supply an argument to a functional call."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout_rate = nn.ValueChoice([0., 1.])

            def forward(self, x):
                return F.dropout(x, self.dropout_rate())

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 3, 3, 3]))
        # The second sample is a dropout rate of 1, which the test expects to zero the output.
        self.assertAlmostEqual(
            self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

    def test_value_choice_in_layer_choice(self):
        """``ValueChoice`` parameters nested inside a ``LayerChoice`` are all mutable."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.LayerChoice([
                    nn.Linear(3, nn.ValueChoice([10, 20])),
                    nn.Linear(3, nn.ValueChoice([30, 40]))
                ])

            def forward(self, x):
                return self.linear(x)

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 3)
        sz_counter = Counter()
        sampler = RandomSampler()
        # 100 random draws are expected to hit all four output widths (10, 20, 30, 40).
        for _ in range(100):
            model_new = model
            for mutator in mutators:
                model_new = mutator.bind_sampler(sampler).apply(model_new)
            sz_counter[self._get_converted_pytorch_model(model_new)(torch.randn(1, 3)).size(1)] += 1
        self.assertEqual(len(sz_counter), 4)

    def test_shared(self):
        """Layer choices share a decision when labelled alike, and not otherwise."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self, shared=True):
                super().__init__()
                labels = ['x', 'x'] if shared else [None, None]
                self.module1 = nn.LayerChoice([
                    nn.Conv2d(3, 3, kernel_size=1),
                    nn.Conv2d(3, 5, kernel_size=1)
                ], label=labels[0])
                self.module2 = nn.LayerChoice([
                    nn.Conv2d(3, 3, kernel_size=1),
                    nn.Conv2d(3, 5, kernel_size=1)
                ], label=labels[1])

            def forward(self, x):
                return self.module1(x) + self.module2(x)

        # Shared label: one mutator, and the sampler is consulted exactly once.
        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        sampler = RandomSampler()
        mutator = mutators[0].bind_sampler(sampler)
        self.assertEqual(self._get_converted_pytorch_model(mutator.apply(model))(torch.randn(1, 3, 3, 3)).size(0), 1)
        self.assertEqual(sampler.counter, 1)

        # Separate labels: two mutators, sampled independently.
        model, mutators = self._get_model_with_mutators(Net(shared=False))
        self.assertEqual(len(mutators), 2)
        sampler = RandomSampler()
        # repeat test. Expectation: sometimes succeeds, sometimes fails.
        # The two branches may end up with 3 and 5 output channels, in which case
        # adding them in ``forward`` is expected to raise RuntimeError.
        failed_count = 0
        for i in range(30):
            model_new = model
            for mutator in mutators:
                model_new = mutator.bind_sampler(sampler).apply(model_new)
            self.assertEqual(sampler.counter, 2 * (i + 1))
            try:
                self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3))
            except RuntimeError:
                failed_count += 1
        self.assertGreater(failed_count, 0)
        self.assertLess(failed_count, 30)

    def test_valuechoice_access(self):
        """Indexing into a ``ValueChoice`` (tuple, dict, list) keeps a single mutator."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                vc = nn.ValueChoice([(6, 3), (8, 5)])
                self.conv = nn.Conv2d(3, vc[0], kernel_size=vc[1])

            def forward(self, x):
                return self.conv(x)

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutators[0].bind_sampler(EnumerateSampler())
        inputs = torch.randn(1, 3, 5, 5)
        # First sample is (6, 3): 6 channels, kernel 3. Second is (8, 5).
        self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(inputs).size(),
                         torch.Size([1, 6, 3, 3]))
        self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(inputs).size(),
                         torch.Size([1, 8, 1, 1]))

        @self.get_serializer()
        class Net2(nn.Module):
            def __init__(self):
                super().__init__()
                choices = [
                    {'b': [3], 'bp': [6]},
                    {'b': [6], 'bp': [12]}
                ]
                self.conv = nn.Conv2d(3, nn.ValueChoice(choices, label='a')['b'][0], 1)
                self.conv1 = nn.Conv2d(nn.ValueChoice(choices, label='a')['bp'][0], 3, 1)

            def forward(self, x):
                x = self.conv(x)
                # 'bp' is twice 'b' in each candidate, matching the doubled channel count.
                return self.conv1(torch.cat((x, x), 1))

        model, mutators = self._get_model_with_mutators(Net2())
        self.assertEqual(len(mutators), 1)
        mutators[0].bind_sampler(EnumerateSampler())
        inputs = torch.randn(1, 3, 5, 5)
        self._get_converted_pytorch_model(mutators[0].apply(model))(inputs)

    def test_valuechoice_access_functional(self):
        """An indexed ``ValueChoice`` result can be passed to a functional call."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout_rate = nn.ValueChoice([[0.,], [1.,]])

            def forward(self, x):
                return F.dropout(x, self.dropout_rate()[0])

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 3, 3, 3]))
        self.assertAlmostEqual(
            self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

    def test_valuechoice_access_functional_expression(self):
        """Arithmetic on an indexed ``ValueChoice`` result is evaluated before use."""
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout_rate = nn.ValueChoice([[1.05,], [1.1,]])

            def forward(self, x):
                # if expression failed, the exception would be:
                # ValueError: dropout probability has to be between 0 and 1, but got 1.05
                return F.dropout(x, self.dropout_rate()[0] - .1)

        model, mutators = self._get_model_with_mutators(Net())
        self.assertEqual(len(mutators), 1)
        mutator = mutators[0].bind_sampler(EnumerateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 3, 3, 3]))
        self.assertAlmostEqual(
            self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441


class Python(GraphIR):
    """Run the inherited ``GraphIR`` tests without generating code from graph IR.

    Models are built by instantiating the recorded Python class directly while
    the sampled mutations are pushed onto a ``'fixed'`` ``ContextStack``.
    """

    def _get_converted_pytorch_model(self, model_ir):
        """Instantiate ``model_ir.python_class`` with the sampled choices fixed."""
        # Map each mutator's label to the value(s) it sampled for this model.
        fixed_choices = {}
        for record in model_ir.history:
            fixed_choices[record.mutator.label] = _unpack_if_only_one(record.samples)
        with ContextStack('fixed', fixed_choices):
            return model_ir.python_class(**model_ir.python_init_params)

    def _get_model_with_mutators(self, pytorch_model):
        """Extract the mutators straight from the PyTorch module."""
        return extract_mutation_from_pt_module(pytorch_model)

    def get_serializer(self):
        """Return ``model_wrapper`` as the class decorator for each test ``Net``."""
        return model_wrapper

    # The tests below are overridden only to skip them for this class.
    @unittest.skip
    def test_value_choice(self):
        pass

    @unittest.skip
    def test_value_choice_in_functional(self):
        pass

    @unittest.skip
    def test_valuechoice_access_functional(self):
        pass

    @unittest.skip
    def test_valuechoice_access_functional_expression(self):
        pass