__init__.py 5.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================


import os
23
import sys
24
import shutil
25
import json
26
27
28

from . import code_generator
from . import search_space_generator
29
from . import specific_code_generator
30
31
32
33


__all__ = ['generate_search_space', 'expand_annotations']

34
35
# Path separator for the current platform: backslash on Windows,
# forward slash everywhere else.
slash = '\\' if sys.platform == "win32" else '/'
37
38
39
40
41
42
43

def generate_search_space(code_dir):
    """Generate search space from Python source code.

    Return a serializable search space object.
    code_dir: directory path of source files (str)

    Every ``.py`` file under code_dir (recursively) is processed and the
    per-file search spaces are merged into a single dict. Module names are
    derived from each file's location relative to code_dir, e.g.
    ``<code_dir>/pkg/sub/model.py`` becomes ``pkg.sub.model``.
    Note: ``_generate_file_search_space`` rewrites each processed file in place.
    """
    search_space = {}
    for subdir, _, files in os.walk(code_dir):
        # relpath normalizes separators and tolerates a trailing separator on
        # code_dir (either '/' or '\\' on Windows), so no manual prefix slicing
        # or assert-based validation is needed. It yields '.' for code_dir itself.
        rel_dir = os.path.relpath(subdir, code_dir)
        if rel_dir == os.curdir:
            package = ''
        else:
            package = rel_dir.replace(os.sep, '.') + '.'

        for file_name in files:
            if file_name.endswith('.py'):
                path = os.path.join(subdir, file_name)
                module = package + file_name[:-3]
                search_space.update(_generate_file_search_space(path, module))

    return search_space

def _generate_file_search_space(path, module):
    """Extract the search space of one source file and rewrite that file.

    path: path of the Python source file (str)
    module: dotted module name of the file, passed to the generator (str)

    The file at ``path`` is overwritten with the code returned by
    ``search_space_generator.generate``, and the extracted search space is
    returned. Failures while reading or processing the source are re-raised as
    RuntimeError prefixed with ``path``.
    """
    with open(path) as src:
        try:
            search_space, code = search_space_generator.generate(module, src.read())
        except Exception as exc:  # pylint: disable=broad-except
            if exc.args:
                # Exception args are not always strings (SyntaxError, for one,
                # carries a location tuple), so stringify each before joining;
                # a bare join would raise TypeError and hide the real error.
                raise RuntimeError(path + ' ' + '\n'.join(str(arg) for arg in exc.args)) from exc
            raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc)) from exc

    with open(path, 'w') as dst:
        dst.write(code)
    return search_space
77
78


79
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
    """Expand annotations in user code.

    Return dst_dir if annotation detected; return src_dir if not.

    src_dir: directory path of user code (str)
    dst_dir: directory to place generated files (str)
    exp_id: experiment id; only used when trial_id is non-empty (str)
    trial_id: if empty, ``.py`` files are expanded generically; otherwise each
        ``.py`` file is generated for this specific trial (str)

    The directory tree under src_dir is mirrored into dst_dir: ``.py`` files
    are processed, every other file is copied unchanged. One trailing path
    separator is stripped from both arguments and the returned path is the
    stripped one.
    Raises ValueError if src_dir or dst_dir is empty.
    """
    if not src_dir or not dst_dir:
        # Previously this surfaced as an IndexError from ``src_dir[-1]``.
        raise ValueError('src_dir and dst_dir must be non-empty directory paths')

    # os.sep matches the module-level ``slash`` ('\\' on Windows, '/' on POSIX).
    if src_dir.endswith(os.sep):
        src_dir = src_dir[:-1]
    if dst_dir.endswith(os.sep):
        dst_dir = dst_dir[:-1]

    annotated = False

    for src_subdir, dirs, files in os.walk(src_dir):
        # Location of this directory relative to src_dir ('.' for src_dir
        # itself); relpath copes with either separator style on Windows.
        rel_dir = os.path.relpath(src_subdir, src_dir)
        if rel_dir == os.curdir:
            dst_subdir = dst_dir
            package = ''
        else:
            dst_subdir = os.path.join(dst_dir, rel_dir)
            package = rel_dir.replace(os.sep, '.') + '.'
        os.makedirs(dst_subdir, exist_ok=True)

        for file_name in files:
            src_path = os.path.join(src_subdir, file_name)
            dst_path = os.path.join(dst_subdir, file_name)
            if not file_name.endswith('.py'):
                shutil.copyfile(src_path, dst_path)
            elif trial_id == '':
                annotated |= _expand_file_annotations(src_path, dst_path)
            else:
                module = package + file_name[:-3]
                annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)

        # Create every listed subdirectory explicitly. This also covers
        # directories os.walk does not descend into (symlinked ones, since
        # followlinks defaults to False), so they still appear in dst_dir.
        for dir_name in dirs:
            os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True)

    return dst_dir if annotated else src_dir

123
124
125
def _expand_file_annotations(src_path, dst_path):
    """Expand the annotations of a single source file.

    src_path: path of the user's source file (str)
    dst_path: path where the result is placed (str)

    If ``code_generator.parse`` returns code, write it to dst_path and return
    True. If it returns None, copy the file verbatim and return False.
    Failures while reading or parsing the source are re-raised as RuntimeError
    prefixed with src_path.
    """
    with open(src_path) as src:
        try:
            annotated_code = code_generator.parse(src.read())
        except Exception as exc:  # pylint: disable=broad-except
            if exc.args:
                raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args)) from exc
            raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc)) from exc

    # dst_path is opened only once its content is known. A failed parse
    # therefore leaves no truncated file behind, and the verbatim copy below
    # never overlaps with another open write handle on the same path.
    if annotated_code is None:
        shutil.copyfile(src_path, dst_path)
        return False
    with open(dst_path, 'w') as dst:
        dst.write(annotated_code)
    return True
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156

def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
    """Generate a trial-specific version of a single source file.

    src_path: path of the user's source file (str)
    dst_path: path where the result is placed (str)
    exp_id: experiment id, used to locate the trial's parameter file (str)
    trial_id: trial id, used to locate the trial's parameter file (str)
    module: dotted module name of the file (str)

    The "parameters" entry of
    ``~/nni/experiments/<exp_id>/trials/<trial_id>/parameter.cfg`` (JSON) is
    passed to ``specific_code_generator.parse``. If that returns code, write it
    to dst_path and return True; if it returns None, copy the file verbatim and
    return False. Failures while loading the parameters or reading/parsing the
    source are re-raised as RuntimeError prefixed with src_path.
    """
    cfg_path = os.path.expanduser('~/nni/experiments/%s/trials/%s/parameter.cfg'%(exp_id, trial_id))
    with open(src_path) as src:
        try:
            with open(cfg_path) as fd:
                para_cfg = json.load(fd)
            annotated_code = specific_code_generator.parse(src.read(), para_cfg["parameters"], module)
        except Exception as exc:  # pylint: disable=broad-except
            if exc.args:
                raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args)) from exc
            raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc)) from exc

    # Only touch dst_path once its content is known. This avoids leaving an
    # empty file on failure and copying onto a file that is open for writing.
    if annotated_code is None:
        shutil.copyfile(src_path, dst_path)
        return False
    with open(dst_path, 'w') as dst:
        dst.write(annotated_code)
    return True