search_space_generator.py 5.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================


import ast
Zejun Lin's avatar
Zejun Lin committed
23
import astor
24
import numbers
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

# pylint: disable=unidiomatic-typecheck


# list of functions related to search space generating
_ss_funcs = [
    'choice',
    'randint',
    'uniform',
    'quniform',
    'loguniform',
    'qloguniform',
    'normal',
    'qnormal',
    'lognormal',
    'qlognormal',
    'function_choice',
    'mutable_layer',
]


Zejun Lin's avatar
Zejun Lin committed
46
class SearchSpaceGenerator(ast.NodeTransformer):
    """Generate search space from smart parameter APIs.

    Walks a module's AST and handles every ``nni.<func>(...)`` call whose
    ``func`` is listed in ``_ss_funcs``. For each one it records an entry in
    ``self.search_space`` and rewrites the call node in place: ordinary smart
    parameters get an extra ``key=...`` keyword, and mutable layers get their
    block-name argument replaced by the module-qualified key.

    Invalid usage is reported by raising ``AssertionError`` with a
    human-readable message; ``self.last_line`` holds the line of the call
    being processed at that moment.
    """

    def __init__(self, module_name):
        self.module_name = module_name
        self.search_space = {}
        self.last_line = 0  # last parsed line, useful for error reporting

    @staticmethod
    def _check(condition, message):
        """Raise ``AssertionError(message)`` unless *condition* is true.

        Used instead of ``assert`` so validation still runs under ``python -O``.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _literal(node, message):
        """Return the Python value of the literal expression *node*.

        ``ast.literal_eval`` accepts AST nodes directly and understands both
        ``ast.Constant`` (Python 3.8+) and the legacy ``ast.Num``/``ast.Str``
        nodes, as well as signed numbers and list/tuple displays.

        Raises:
            AssertionError(message): *node* is not a literal.
        """
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError, SyntaxError):
            raise AssertionError(message) from None

    @staticmethod
    def _set_str(node, text):
        """Replace the string held by literal *node* in place, keeping its source location."""
        if hasattr(ast, 'Constant') and isinstance(node, ast.Constant):
            node.value = text
        else:
            node.s = text  # legacy ast.Str node (Python < 3.8)

    @staticmethod
    def _make_str_node(text):
        """Build a string-literal node of the kind the running Python's parser produces."""
        import sys
        if sys.version_info >= (3, 8):
            return ast.Constant(value=text)
        return ast.Str(s=text)

    def generate_mutable_layer_search_space(self, args):
        """Record one ``nni.mutable_layer`` call in ``self.search_space``.

        args: the call's positional argument nodes. This method reads
        ``args[0]`` (block name, string literal; rewritten in place to
        ``<module_name>/<block>``), ``args[1]`` (layer name, string literal),
        the keys of the dicts ``args[2]`` (layer choices) and ``args[5]``
        (optional inputs), and ``args[6]`` (optional input size: a number
        literal, or a sequence literal whose first two items are kept).

        Raises:
            AssertionError: an argument read here is not of the expected literal form.
        """
        mutable_block = self._literal(args[0], 'Mutable block name must be a string literal')
        mutable_layer = self._literal(args[1], 'Mutable layer name must be a string literal')
        self._check(isinstance(mutable_block, str), 'Mutable block name must be a string literal')

        key = self.module_name + '/' + mutable_block
        self._set_str(args[0], key)
        block_space = self.search_space.setdefault(key, {'_type': 'mutable_layer', '_value': {}})

        choice_msg = 'Mutable layer choices and inputs must be dicts with literal keys'
        self._check(isinstance(args[2], ast.Dict) and isinstance(args[5], ast.Dict), choice_msg)
        layer_choice = [self._literal(k, choice_msg) for k in args[2].keys]
        optional_inputs = [self._literal(k, choice_msg) for k in args[5].keys]

        size_msg = 'Optional input size must be a number or a pair of numbers'
        size = self._literal(args[6], size_msg)
        if isinstance(size, (list, tuple)):
            # a range is given: keep its two bounds as a list
            self._check(len(size) >= 2, size_msg)
            size = [size[0], size[1]]

        block_space['_value'][mutable_layer] = {
            'layer_choice': layer_choice,
            'optional_inputs': optional_inputs,
            'optional_input_size': size,
        }

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """Collect the search space of one ``nni.*`` smart-parameter call.

        Children are visited first so nested calls are handled before their
        parent. Calls that are not ``nni.<func>`` with ``func`` in
        ``_ss_funcs`` are returned untouched.

        Raises:
            AssertionError: the smart parameter call is malformed.
        """
        self.generic_visit(node)

        # ignore if the function is not 'nni.*'
        callee = node.func
        if not (isinstance(callee, ast.Attribute)
                and isinstance(callee.value, ast.Name)
                and callee.value.id == 'nni'):
            return node

        # ignore if it is not a search space function (e.g. `report_final_result`)
        func = callee.attr
        if func not in _ss_funcs:
            return node

        self.last_line = node.lineno

        if func == 'mutable_layer':
            self.generate_mutable_layer_search_space(node.args)
            return node

        if node.keywords:
            # there is a `name` argument
            self._check(len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"')
            self._check(node.keywords[0].arg == 'name', 'Smart parameter\'s keyword argument is not "name"')
            name_msg = 'Smart parameter\'s name must be string literal'
            name = self._literal(node.keywords[0].value, name_msg)
            self._check(isinstance(name, str), name_msg)
            specified_name = True
        else:
            # generate the missing name automatically from the last argument's line;
            # fall back to the call's own line when there are no positional arguments
            anchor = node.args[-1] if node.args else node
            name = '__line' + str(anchor.lineno)
            specified_name = False
            node.keywords = []

        if func in ('choice', 'function_choice'):
            # we will use keys in the dict as the choices, which is generated by code_generator
            # according to the args given by user; each key is a number or string literal
            dict_msg = 'Smart parameter has arguments other than dict'
            self._check(len(node.args) == 1, dict_msg)
            self._check(isinstance(node.args[0], ast.Dict), dict_msg)
            args = [self._literal(k, 'Choice keys must be number or string literals')
                    for k in node.args[0].keys]
        else:
            # arguments of other functions must be literal numbers
            number_msg = 'Smart parameter\'s arguments must be number literals'
            args = [self._literal(arg, number_msg) for arg in node.args]
            self._check(all(isinstance(arg, numbers.Real) for arg in args), number_msg)

        key = self.module_name + '/' + name + '/' + func
        # store key in ast.Call
        node.keywords.append(ast.keyword(arg='key', value=self._make_str_node(key)))

        if func == 'function_choice':
            func = 'choice'
        value = {'_type': func, '_value': args}

        if specified_name:
            # multiple functions with same name must have identical arguments
            old = self.search_space.get(key)
            self._check(old is None or old == value, 'Different smart parameters have same name')
        else:
            # generated name must not duplicate
            self._check(key not in self.search_space, 'Only one smart parameter is allowed in a line')

        self.search_space[key] = value
        return node

134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

def generate(module_name, code):
    """Generate search space.

    Parse the user code, collect every smart parameter call into a search
    space and rewrite those calls to carry their generated keys.

    module_name: name of the module (str); prefixes every search space key
    code: user code (str)

    Returns a ``(search_space, new_code)`` tuple. ``search_space`` is the
    collected search space dict. ``new_code`` is the rewritten source produced
    by ``astor.to_source``.

    Raises RuntimeError if the code cannot be parsed ('Bad Python code'), or if
    a smart parameter call is invalid. In the second case the message is
    prefixed with the offending line number. The original exception is chained
    as ``__cause__``.
    """
    try:
        ast_tree = ast.parse(code)
    except Exception as exc:
        raise RuntimeError('Bad Python code') from exc

    visitor = SearchSpaceGenerator(module_name)
    try:
        visitor.visit(ast_tree)
    except AssertionError as exc:
        # str(exc) rather than exc.args[0]: a message-less AssertionError must not turn into IndexError
        raise RuntimeError('%d: %s' % (visitor.last_line, exc)) from exc
    return visitor.search_space, astor.to_source(ast_tree)