# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
5
import sys
6
import shutil
7
import json
8
9
10

from . import code_generator
from . import search_space_generator
11
from . import specific_code_generator
12
13
14
15


# Public API of this package: the two entry points defined further down in this file.
__all__ = ['generate_search_space', 'expand_annotations']

16
17
# Path separator of the host platform ('\\' on Windows, '/' elsewhere).
# Taken from the standard library rather than derived by hand from sys.platform.
slash = os.sep
19
20
21
22
23
24
25

def generate_search_space(code_dir):
    """Generate search space from Python source code.

    Walk code_dir recursively and merge the search space of every ``.py``
    file found into one dict. Each file is handed to
    ``_generate_file_search_space`` together with its dotted module name,
    which is derived from the file's path relative to code_dir.

    code_dir: directory path of source files (str); any number of trailing
        path separators is accepted

    Return a serializable search space object (dict).
    """
    # Drop every trailing separator (os.sep and, where defined, os.altsep).
    # Fall back to the original string when stripping would erase it,
    # e.g. for the root directory.
    separators = os.sep + (os.altsep or '')
    code_dir = code_dir.rstrip(separators) or code_dir

    search_space = {}

    for subdir, _, files in os.walk(code_dir):
        # generate module name from path
        relative = os.path.relpath(subdir, code_dir)
        if relative == os.curdir:
            package = ''
        else:
            package = relative.replace(os.sep, '.') + '.'

        for file_name in files:
            if file_name.endswith('.py'):
                path = os.path.join(subdir, file_name)
                module = package + file_name[:-3]
                search_space.update(_generate_file_search_space(path, module))

    return search_space

def _generate_file_search_space(path, module):
    """Generate the search space of one source file and rewrite that file.

    The file content is passed to ``search_space_generator.generate``, which
    returns the search space and a new version of the code. The new code is
    then written back to the same path, replacing the original content.

    path: path of the Python source file (str)
    module: dotted module name of the file, forwarded to the generator (str)

    Return the search space produced for this file.
    Raise RuntimeError (chained to the original exception) if reading the
    file or generating the search space fails.
    """
    with open(path) as src:
        try:
            search_space, code = search_space_generator.generate(module, src.read())
        except Exception as exc:  # pylint: disable=broad-except
            if exc.args:
                # Exception args are not guaranteed to be strings; convert them
                # so that building the message cannot itself raise TypeError.
                raise RuntimeError(path + ' ' + '\n'.join(str(arg) for arg in exc.args)) from exc
            raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc)) from exc
    # The source handle is closed by now, so the file can be overwritten.
    with open(path, 'w') as dst:
        dst.write(code)
    return search_space
59
60


61
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id='', nas_mode=None):
    """Expand annotations in user code.

    Mirror the directory tree of src_dir under dst_dir: every ``.py`` file is
    processed by the annotation expander, every other file is copied as is.
    The tree under dst_dir is created even when no annotation is found.

    Return dst_dir if annotation detected; return src_dir if not. Both are
    returned without trailing path separators.

    src_dir: directory path of user code (str)
    dst_dir: directory to place generated files (str)
    exp_id: experiment ID, used together with trial_id to locate the trial's
        parameter file (str)
    trial_id: trial ID; when empty (the default) annotations are expanded
        with nas_mode, otherwise code specific to that trial is generated (str)
    nas_mode: the mode of NAS given that NAS interface is used
    """
    # Drop every trailing separator (os.sep and, where defined, os.altsep).
    # Fall back to the original string when stripping would erase it,
    # e.g. for the root directory. Works on empty strings as well.
    separators = os.sep + (os.altsep or '')
    src_dir = src_dir.rstrip(separators) or src_dir
    dst_dir = dst_dir.rstrip(separators) or dst_dir

    annotated = False

    for src_subdir, dirs, files in os.walk(src_dir):
        relative = os.path.relpath(src_subdir, src_dir)
        if relative == os.curdir:
            dst_subdir = dst_dir
            package = ''
        else:
            dst_subdir = os.path.join(dst_dir, relative)
            # generate module name from path
            package = relative.replace(os.sep, '.') + '.'
        os.makedirs(dst_subdir, exist_ok=True)

        for file_name in files:
            src_path = os.path.join(src_subdir, file_name)
            dst_path = os.path.join(dst_subdir, file_name)
            if file_name.endswith('.py'):
                if trial_id == '':
                    annotated |= _expand_file_annotations(src_path, dst_path, nas_mode)
                else:
                    module = package + file_name[:-3]
                    annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
            else:
                shutil.copyfile(src_path, dst_path)

        for dir_name in dirs:
            os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True)

    return dst_dir if annotated else src_dir

106
def _expand_file_annotations(src_path, dst_path, nas_mode):
    """Expand the annotations of one source file into dst_path.

    The content of src_path is parsed with ``code_generator.parse``. When the
    parser returns None the file is copied to dst_path unchanged; otherwise
    the returned code is written to dst_path.

    The destination is only opened after the source has been read and parsed,
    so a parse failure does not leave a truncated file behind and the source
    is never emptied before it is read.

    src_path: path of the user's source file (str)
    dst_path: path of the file to create (str)
    nas_mode: forwarded unchanged to ``code_generator.parse``

    Return True if the parser produced code, False if the file was copied.
    Raise RuntimeError (chained to the original exception) if reading,
    parsing, copying or writing fails.
    """
    def _wrap(exc):
        # Build the RuntimeError reported for any failure while handling src_path.
        if exc.args:
            return RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
        return RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))

    with open(src_path) as src:
        try:
            annotated_code = code_generator.parse(src.read(), nas_mode)
            if annotated_code is None:
                # Nothing to expand: dst_path has not been opened, plain copy is safe.
                shutil.copyfile(src_path, dst_path)
                return False
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrap(exc) from exc

    with open(dst_path, 'w') as dst:
        try:
            dst.write(annotated_code)
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrap(exc) from exc
    return True
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138

def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
    """Generate the trial-specific version of one source file into dst_path.

    The trial's parameters are loaded from
    ``~/nni/experiments/<exp_id>/trials/<trial_id>/parameter.cfg`` (JSON) and
    its ``"parameters"`` entry is passed, together with the content of
    src_path and the module name, to ``specific_code_generator.parse``. When
    the parser returns None the file is copied to dst_path unchanged;
    otherwise the returned code is written to dst_path.

    The destination is only opened after the source has been read and parsed,
    so a failure does not leave a truncated file behind and the source is
    never emptied before it is read.

    src_path: path of the user's source file (str)
    dst_path: path of the file to create (str)
    exp_id: experiment ID used to locate the parameter file (str)
    trial_id: trial ID used to locate the parameter file (str)
    module: dotted module name of the file, forwarded to the parser (str)

    Return True if the parser produced code, False if the file was copied.
    Raise RuntimeError (chained to the original exception) if loading the
    parameters, reading, parsing, copying or writing fails.
    """
    def _wrap(exc):
        # Build the RuntimeError reported for any failure while handling src_path.
        if exc.args:
            return RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
        return RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))

    with open(src_path) as src:
        try:
            with open(os.path.expanduser('~/nni/experiments/%s/trials/%s/parameter.cfg'%(exp_id, trial_id))) as fd:
                para_cfg = json.load(fd)
            annotated_code = specific_code_generator.parse(src.read(), para_cfg["parameters"], module)
            if annotated_code is None:
                # Nothing to generate: dst_path has not been opened, plain copy is safe.
                shutil.copyfile(src_path, dst_path)
                return False
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrap(exc) from exc

    with open(dst_path, 'w') as dst:
        try:
            dst.write(annotated_code)
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrap(exc) from exc
    return True