test_annotation.py 2.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================


# pylint: skip-file

from .__init__ import *

import ast
import json
import os
import shutil
from unittest import TestCase, main


class AnnotationTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        os.chdir('nni_annotation')
        if os.path.isdir('_generated'):
            shutil.rmtree('_generated')

    def test_search_space_generator(self):
        search_space = generate_search_space('testcase/annotated')
        with open('testcase/searchspace.json') as f:
            self.assertEqual(search_space, json.load(f))

    def test_code_generator(self):
        expand_annotations('testcase/usercode', '_generated')
        self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py')
        self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py')
        with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst:
            assert src.read() == dst.read()

    def _assert_source_equal(self, src1, src2):
        with open(src1) as f1, open(src2) as f2:
            ast1 = ast.dump(ast.parse(f1.read()))
            ast2 = ast.dump(ast.parse(f2.read()))
        self.assertEqual(ast1, ast2)


if __name__ == '__main__':
    # Only runs when this module is executed as the main program; when a test
    # runner imports it, nothing here executes.  module='__main__' is
    # unittest.main's default, spelled out for clarity.
    main(module='__main__')