test_annotation.py 1.9 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4
5
6
7

# pylint: skip-file

from .__init__ import *

chicm-ms's avatar
chicm-ms committed
8
import sys
9
10
11
12
import ast
import json
import os
import shutil
fishyds's avatar
fishyds committed
13
import tempfile
chicm-ms's avatar
chicm-ms committed
14
from unittest import TestCase, main, skipIf
15
16
17
18
19
20
21
22
23
24


class AnnotationTestCase(TestCase):
    """Tests for search-space generation and annotation expansion.

    All fixture paths are relative to the ``nni_annotation`` directory, so the
    class switches into it for the duration of its tests and switches back
    afterwards. Generated output is written under ``_generated``.
    """

    @classmethod
    def setUpClass(cls):
        # Remember the caller's working directory so tearDownClass can undo
        # the chdir; otherwise every test that runs after this class would
        # inherit the changed working directory.
        cls._orig_cwd = os.getcwd()
        os.chdir('nni_annotation')
        try:
            # Start from a clean slate: copytree() in the tests fails if the
            # destination already exists from a previous run.
            if os.path.isdir('_generated'):
                shutil.rmtree('_generated')
        except BaseException:
            # tearDownClass is not called when setUpClass raises, so restore
            # the working directory here before propagating the error.
            os.chdir(cls._orig_cwd)
            raise

    @classmethod
    def tearDownClass(cls):
        # Undo the chdir from setUpClass. The ``_generated`` output is left in
        # place; setUpClass removes it at the start of the next run.
        os.chdir(cls._orig_cwd)

    def test_search_space_generator(self):
        """generate_search_space() on the annotated fixtures matches searchspace.json."""
        shutil.copytree('testcase/annotated', '_generated/annotated')
        search_space = generate_search_space('_generated/annotated')
        with open('testcase/searchspace.json') as f:
            self.assertEqual(search_space, json.load(f))

    @skipIf(sys.version_info.major == 3 and sys.version_info.minor > 7, "skip for python3.8 temporarily")
    def test_code_generator(self):
        """expand_annotations() turns testcase/usercode into the annotated reference sources."""
        code_dir = expand_annotations('testcase/usercode', '_generated/usercode', nas_mode='classic_mode')
        self.assertEqual(code_dir, '_generated/usercode')
        self._assert_source_equal('testcase/annotated/nas.py', '_generated/usercode/nas.py')
        self._assert_source_equal('testcase/annotated/mnist.py', '_generated/usercode/mnist.py')
        self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/usercode/dir/simple.py')
        # Non-Python files must appear in the output with identical contents.
        with open('testcase/usercode/nonpy.txt') as src, open('_generated/usercode/nonpy.txt') as dst:
            self.assertEqual(src.read(), dst.read())

    def test_annotation_detecting(self):
        """For a directory without annotations, expand_annotations() returns that directory itself."""
        dir_ = 'testcase/usercode/non_annotation'
        # Use a self-cleaning temporary directory; mkdtemp() would leak it.
        with tempfile.TemporaryDirectory() as out_dir:
            code_dir = expand_annotations(dir_, out_dir)
        self.assertEqual(code_dir, dir_)

    def _assert_source_equal(self, src1, src2):
        """Assert that two Python source files parse to the same AST.

        Comparing ``ast.dump`` output ignores formatting and comments, since
        neither survives parsing. Source positions are excluded by default.
        """
        with open(src1) as f1, open(src2) as f2:
            ast1 = ast.dump(ast.parse(f1.read()))
            ast2 = ast.dump(ast.parse(f2.read()))
        self.assertEqual(ast1, ast2)


if __name__ == '__main__':
    # Run this module's tests via unittest.main() when executed directly.
    main()