# search_space_generator.py
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================


import ast
import numbers
import sys

import astor

# pylint: disable=unidiomatic-typecheck


# list of functions related to search space generating
_ss_funcs = [
    'choice',
    'randint',
    'uniform',
    'quniform',
    'loguniform',
    'qloguniform',
    'normal',
    'qnormal',
    'lognormal',
    'qlognormal',
    'function_choice'
]


class SearchSpaceGenerator(ast.NodeTransformer):
    """Generate search space from smart parameter APIs.

    Visiting a module AST records every ``nni.<func>(...)`` call whose ``<func>``
    is listed in ``_ss_funcs`` into ``self.search_space`` (keyed by
    ``'<module_name>/<name>/<func>'``), and rewrites that call in place by
    appending a ``key=<that key>`` keyword argument.

    Invalid smart-parameter calls raise ``AssertionError`` with a readable
    message, and ``self.last_line`` holds the line of the offending call. The
    checks raise explicitly instead of using ``assert`` statements, so they still
    run under ``python -O``.
    """

    def __init__(self, module_name):
        super().__init__()
        self.module_name = module_name
        self.search_space = {}
        self.last_line = 0  # last parsed line, useful for error reporting

    @staticmethod
    def _check(condition, message):
        """Raise ``AssertionError(message)`` if ``condition`` is false."""
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _literal(node, message):
        """Return the value of the literal expression ``node``.

        Raises ``AssertionError(message)`` if ``node`` is not a literal. The
        failure is therefore reported as a smart-parameter error rather than
        as a raw ``ValueError``/``TypeError``.
        """
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError) as err:
            raise AssertionError(message) from err

    @staticmethod
    def _str_node(value):
        """Build a string-literal AST node in the form the running parser emits."""
        if sys.version_info >= (3, 8):
            # `ast.Str` is deprecated since 3.8 (removed in 3.14); the parser emits `ast.Constant`
            return ast.Constant(value=value)
        return ast.Str(s=value)

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """Validate and record ``node`` if it is a smart-parameter call.

        Returns ``node``. For smart-parameter calls it has been modified in
        place with an extra ``key`` keyword; other calls are returned unchanged.
        """
        self.generic_visit(node)

        # ignore if the function is not 'nni.*'
        callee = node.func
        if not (isinstance(callee, ast.Attribute)
                and isinstance(callee.value, ast.Name)
                and callee.value.id == 'nni'):
            return node

        # ignore if it's not a search space function (e.g. `report_final_result`)
        func = callee.attr
        if func not in _ss_funcs:
            return node

        self.last_line = node.lineno

        if node.keywords:
            # there is a `name` argument
            self._check(len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"')
            self._check(node.keywords[0].arg == 'name', 'Smart parameter\'s keyword argument is not "name"')
            name_error = 'Smart parameter\'s name must be string literal'
            name = self._literal(node.keywords[0].value, name_error)
            self._check(isinstance(name, str), name_error)
            specified_name = True
        else:
            # generate the missing name automatically from the last argument's line;
            # fall back to the call's own line when there are no arguments
            anchor = node.args[-1] if node.args else node
            name = '__line' + str(anchor.lineno)
            specified_name = False
            node.keywords = []

        if func in ('choice', 'function_choice'):
            # we will use keys in the dict as the choices, which is generated by code_generator
            # according to the args given by user
            dict_error = 'Smart parameter has arguments other than dict'
            self._check(len(node.args) == 1, dict_error)
            self._check(isinstance(node.args[0], ast.Dict), dict_error)
            # each key must be a number or a string literal
            key_error = 'Smart parameter\'s choices must be number or string literals'
            args = [self._literal(key, key_error) for key in node.args[0].keys]
            self._check(all(isinstance(arg, (numbers.Number, str)) for arg in args), key_error)
        else:
            # arguments of other functions must be literal numbers
            number_error = 'Smart parameter\'s arguments must be number literals'
            args = [self._literal(arg, number_error) for arg in node.args]
            self._check(all(isinstance(arg, numbers.Real) for arg in args), number_error)

        key = self.module_name + '/' + name + '/' + func

        # store key in ast.Call
        node.keywords.append(ast.keyword(arg='key', value=self._str_node(key)))

        if func == 'function_choice':
            func = 'choice'
        value = {'_type': func, '_value': args}

        if specified_name:
            # multiple functions with same name must have identical arguments
            old = self.search_space.get(key)
            self._check(old is None or old == value, 'Different smart parameters have same name')
        else:
            # generated name must not duplicate
            self._check(key not in self.search_space, 'Only one smart parameter is allowed in a line')

        self.search_space[key] = value

        return node


def generate(module_name, code):
    """Generate search space.

    module_name: name of the module (str); used as the first component of every
        search-space key
    code: user code (str)

    Returns a tuple ``(search_space, new_code)``. ``search_space`` is the
    serializable search space dict. ``new_code`` is the source regenerated by
    ``astor`` from the AST, after every smart-parameter call has been given a
    ``key`` keyword argument.

    Raises RuntimeError('Bad Python code') if ``code`` cannot be parsed, and
    RuntimeError('<line>: <message>') if a smart-parameter call is invalid.
    """
    try:
        ast_tree = ast.parse(code)
    except Exception as exc:
        raise RuntimeError('Bad Python code') from exc

    visitor = SearchSpaceGenerator(module_name)
    try:
        visitor.visit(ast_tree)
    except AssertionError as exc:
        # str(exc) rather than exc.args[0]: a message-less AssertionError has empty args
        raise RuntimeError('%d: %s' % (visitor.last_line, exc)) from exc

    return visitor.search_space, astor.to_source(ast_tree)