# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import os
import sys
from unittest import TestCase, main

import torch
import torch.nn as nn
from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.enas import EnasMutator
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.random import RandomMutator
from nni.nas.pytorch.utils import _reset_global_mutable_counting


class NasTestCase(TestCase):
    """Sanity checks for NNI NAS mutators against the bundled PyTorch search spaces."""

    def setUp(self):
        # Per-sample input shape (channels, height, width); the batch
        # dimension is prepended in iterative_sample_and_forward.
        self.default_input_size = [3, 32, 32]
        # Make the sibling "models" directory importable for the duration of a test.
        self.model_path = os.path.join(os.path.dirname(__file__), "models")
        sys.path.append(self.model_path)
        self.model_module = importlib.import_module("pytorch_models")
        self.default_cls = [
            self.model_module.NaiveSearchSpace,
            self.model_module.SpaceWithMutableScope,
        ]
        # GPU counts to exercise: 0 = CPU only, 1 = single GPU,
        # N > 1 additionally wraps the model in DataParallel.
        self.cuda_test = [0]
        if torch.cuda.is_available():
            self.cuda_test.append(1)
        if torch.cuda.device_count() > 1:
            self.cuda_test.append(torch.cuda.device_count())

    def tearDown(self):
        # Undo the sys.path modification made in setUp.
        sys.path.remove(self.model_path)

    def iterative_sample_and_forward(self, model, mutator=None, input_size=None, n_iters=20, test_backward=True,
                                     use_cuda=False):
        """Repeatedly resample an architecture (when a mutator is given) and run
        forward — and optionally backward — passes on random input.

        PyTorch only. ``use_cuda`` is truthy when at least one GPU is in play.
        """
        if input_size is None:
            input_size = self.default_input_size
        # Prepend the batch dimension; at least 2 samples so batch norm works.
        batch_size = 8 if use_cuda else 2
        full_size = [batch_size] + input_size
        for _ in range(n_iters):
            # Clear stale gradients without needing an optimizer.
            for weight in model.parameters():
                weight.grad = None
            if mutator is not None:
                mutator.reset()
            data = torch.randn(full_size)
            if use_cuda:
                data = data.cuda()
            loss = model(data).sum()
            if test_backward:
                loss.backward()

    def default_mutator_test_pipeline(self, mutator_cls):
        """Search with ``mutator_cls``, export the chosen architecture, then
        re-instantiate the model with that architecture fixed and run it once."""
        for model_cls in self.default_cls:
            for n_gpus in self.cuda_test:
                # Reset global counters so mutable keys line up between the
                # search model and the fixed model built below.
                _reset_global_mutable_counting()
                search_model = model_cls(self)
                mutator = mutator_cls(search_model)
                if n_gpus:
                    search_model.cuda()
                    mutator.cuda()
                    if n_gpus > 1:
                        search_model = nn.DataParallel(search_model)
                self.iterative_sample_and_forward(search_model, mutator, use_cuda=n_gpus)
                _reset_global_mutable_counting()
                fixed_model = model_cls(self)
                if n_gpus:
                    fixed_model.cuda()
                    if n_gpus > 1:
                        fixed_model = nn.DataParallel(fixed_model)
                with torch.no_grad():
                    chosen_arc = mutator.export()
                apply_fixed_architecture(fixed_model, chosen_arc)
                self.iterative_sample_and_forward(fixed_model, n_iters=1, use_cuda=n_gpus)

    def test_random_mutator(self):
        self.default_mutator_test_pipeline(RandomMutator)

    def test_enas_mutator(self):
        self.default_mutator_test_pipeline(EnasMutator)

    def test_darts_mutator(self):
        # DARTS doesn't support DataParallel. To be fixed.
        self.cuda_test = [n for n in self.cuda_test if n <= 1]
        self.default_mutator_test_pipeline(DartsMutator)

    def test_apply_twice(self):
        # Attaching a second mutator to the same model must be rejected.
        model = self.model_module.NaiveSearchSpace(self)
        with self.assertRaises(RuntimeError):
            for _ in range(2):
                RandomMutator(model)

    def test_nested_space(self):
        # Nested mutables are unsupported and must raise on mutator creation.
        model = self.model_module.NestedSpace(self)
        with self.assertRaises(RuntimeError):
            RandomMutator(model)

    def test_classic_nas(self):
        # Classic NAS path: pull the next architecture and apply it in place.
        for model_cls in self.default_cls:
            model = model_cls(self)
            get_and_apply_next_architecture(model)
            self.iterative_sample_and_forward(model)

if __name__ == '__main__':
    # Run all test cases in this module via unittest's CLI runner.
    main()