test_annotation.py 1.78 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4
5
6
7
8
9
10
11

# pylint: skip-file

from .__init__ import *

import ast
import json
import os
import shutil
fishyds's avatar
fishyds committed
12
import tempfile
13
14
15
16
17
18
19
20
21
22
23
from unittest import TestCase, main


class AnnotationTestCase(TestCase):
    """Tests for the annotation helpers brought in by ``from .__init__ import *``.

    Every path used by the tests is relative to the ``nni_annotation``
    directory, which ``setUpClass`` makes the current working directory for
    the duration of the class. Output is written under ``_generated`` and is
    left on disk after the run.
    """

    @classmethod
    def setUpClass(cls):
        """Enter ``nni_annotation`` and discard output of any earlier run."""
        # The working directory is process-wide state; remember it so
        # tearDownClass can put it back for whatever runs after this class.
        cls._original_cwd = os.getcwd()
        os.chdir('nni_annotation')
        if os.path.isdir('_generated'):
            shutil.rmtree('_generated')

    @classmethod
    def tearDownClass(cls):
        """Restore the working directory changed by ``setUpClass``."""
        os.chdir(cls._original_cwd)

    def test_search_space_generator(self):
        """The generated search space matches ``testcase/searchspace.json``."""
        # Work on a copy so the checked-in fixtures stay untouched.
        # NOTE(review): presumably generate_search_space may write into the
        # directory it is given -- confirm against its implementation.
        shutil.copytree('testcase/annotated', '_generated/annotated')
        search_space = generate_search_space('_generated/annotated')
        with open('testcase/searchspace.json') as f:
            self.assertEqual(search_space, json.load(f))

    def test_code_generator(self):
        """Expanding ``testcase/usercode`` reproduces ``testcase/annotated``."""
        code_dir = expand_annotations('testcase/usercode', '_generated/usercode', nas_mode='classic_mode')
        self.assertEqual(code_dir, '_generated/usercode')

        # Python sources are compared structurally (by AST), not textually.
        for relative_path in ('nas.py', 'mnist.py', 'dir/simple.py'):
            self._assert_source_equal(
                'testcase/annotated/' + relative_path,
                '_generated/usercode/' + relative_path,
            )

        # The non-Python file must be identical to its source. assertEqual
        # is used rather than a bare ``assert`` so the check still runs
        # under ``python -O`` and reports a diff on failure.
        with open('testcase/usercode/nonpy.txt') as src, open('_generated/usercode/nonpy.txt') as dst:
            self.assertEqual(src.read(), dst.read())

    def test_annotation_detecting(self):
        """For the ``non_annotation`` fixture the source dir itself is returned."""
        dir_ = 'testcase/usercode/non_annotation'
        scratch_dir = tempfile.mkdtemp()
        # mkdtemp leaves deletion to the caller. expand_annotations may or
        # may not have removed the directory already, so ignore errors.
        self.addCleanup(shutil.rmtree, scratch_dir, ignore_errors=True)
        code_dir = expand_annotations(dir_, scratch_dir)
        self.assertEqual(code_dir, dir_)

    def _assert_source_equal(self, src1, src2):
        """Assert that two Python source files parse to the same AST.

        Args:
            src1: Path of the first source file.
            src2: Path of the second source file.

        Differences that do not reach the AST (comments, whitespace
        layout) are therefore ignored.
        """
        with open(src1) as f1, open(src2) as f2:
            ast1 = ast.dump(ast.parse(f1.read()))
            ast2 = ast.dump(ast.parse(f2.read()))
        self.assertEqual(ast1, ast2)


# Run the tests through unittest's ``main`` when this file is executed
# directly; importing the module does not trigger a run.
if __name__ == '__main__':
    main()