search_space_generator.py 4.67 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4

import ast
5
import numbers
6

liuzhe-lz's avatar
liuzhe-lz committed
7
8
import astor

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# pylint: disable=unidiomatic-typecheck


# Names of the ``nni.*`` functions whose calls contribute search-space entries.
# Except for the last two, each name is stored verbatim as the entry's ``_type``;
# ``function_choice`` is recorded as ``choice`` and ``mutable_layer`` builds its
# own nested entry.
_ss_funcs = [
    # plain search-space types
    'choice', 'randint',
    'uniform', 'quniform', 'loguniform', 'qloguniform',
    'normal', 'qnormal', 'lognormal', 'qlognormal',
    # special cases
    'function_choice', 'mutable_layer',
]


Zejun Lin's avatar
Zejun Lin committed
29
def _require(condition, message):
    """Raise ``AssertionError(message)`` when ``condition`` is falsy.

    Raising explicitly (instead of using ``assert``) keeps the validation active
    under ``python -O``; ``generate`` turns the error into a line-tagged
    ``RuntimeError``.
    """
    if not condition:
        raise AssertionError(message)


def _literal_value(node, message):
    """Return the Python value of the literal expression ``node``.

    ``ast.literal_eval`` accepts AST nodes directly and understands both the
    legacy ``ast.Num``/``ast.Str`` nodes (Python < 3.8) and ``ast.Constant``,
    so no deprecated ``.n``/``.s`` attribute access is needed.
    A non-literal node raises ``AssertionError(message)`` rather than
    ``ValueError``/``TypeError`` so it is reported like any other misuse.
    """
    try:
        return ast.literal_eval(node)
    except (ValueError, TypeError) as err:
        raise AssertionError(message) from err


def _string_literal(node, message):
    """Return the value of ``node`` if it is a string literal, else fail with ``message``."""
    value = _literal_value(node, message)
    _require(isinstance(value, str), message)
    return value


def _dict_key_values(node, message):
    """Return the literal values of the keys of the dict display ``node``, in order."""
    _require(isinstance(node, ast.Dict), message)
    return [_literal_value(key, message) for key in node.keys]


def _string_node(text):
    """Build a string-literal node of the kind ``ast.parse`` produces on this Python."""
    import sys  # only needed to choose between ast.Constant and the legacy ast.Str
    if sys.version_info >= (3, 8):
        return ast.Constant(value=text)
    return ast.Str(s=text)


def _set_string(node, text):
    """Overwrite, in place, the value of the string-literal node ``node``."""
    if isinstance(node, getattr(ast, 'Constant', ())):
        node.value = text
    else:  # legacy ast.Str produced by ast.parse on Python < 3.8
        node.s = text


class SearchSpaceGenerator(ast.NodeTransformer):
    """Generate search space from smart parameter APIs.

    Visiting a parsed module records every ``nni.<func>(...)`` call whose
    ``<func>`` is listed in ``_ss_funcs`` in ``self.search_space`` and rewrites
    the call in place: ``mutable_layer`` calls get their block-name argument
    prefixed with the module name, and every other smart parameter gets an
    extra ``key='<module>/<name>/<func>'`` keyword argument.

    Misuse raises ``AssertionError``; ``self.last_line`` then holds the line
    number of the smart-parameter call that was being processed.
    """

    def __init__(self, module_name):
        self.module_name = module_name
        self.search_space = {}
        self.last_line = 0  # last parsed line, useful for error reporting

    def generate_mutable_layer_search_space(self, args):
        """Record one ``nni.mutable_layer`` call in ``self.search_space``.

        ``args`` is the call's positional-argument node list: index 0 is the
        mutable block name, 1 the layer name, 2 a dict whose keys are the layer
        choices, 5 a dict whose keys are the optional inputs and 6 the optional
        input size (a number, or a sequence whose first two items are kept).
        Indices 3 and 4 are not read here.  The block-name literal ``args[0]``
        is rewritten in place to ``<module_name>/<block>``.
        """
        _require(len(args) >= 7, 'Mutable layer has too few arguments')
        names_message = 'Mutable layer\'s block and layer names must be string literals'
        mutable_block = _string_literal(args[0], names_message)
        mutable_layer = _string_literal(args[1], names_message)

        dicts_message = 'Mutable layer\'s choices and optional inputs must be dicts with literal keys'
        layer_choice = _dict_key_values(args[2], dicts_message)
        optional_inputs = _dict_key_values(args[5], dicts_message)

        size_message = 'Mutable layer\'s optional input size must be a number or a pair'
        optional_input_size = _literal_value(args[6], size_message)
        if not isinstance(optional_input_size, numbers.Real):
            _require(isinstance(optional_input_size, (list, tuple)) and len(optional_input_size) >= 2,
                     size_message)
            optional_input_size = [optional_input_size[0], optional_input_size[1]]

        # only touch the tree and the search space once every argument is valid
        key = self.module_name + '/' + mutable_block
        _set_string(args[0], key)
        if key not in self.search_space:
            self.search_space[key] = {'_type': 'mutable_layer', '_value': {}}
        self.search_space[key]['_value'][mutable_layer] = {
            'layer_choice': layer_choice,
            'optional_inputs': optional_inputs,
            'optional_input_size': optional_input_size,
        }

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """Record ``node`` in the search space if it is an ``nni`` smart-parameter call.

        Children are visited first, so nested calls are handled before their
        enclosing call.  Always returns ``node`` so ``ast.NodeTransformer``
        keeps it in the tree.
        """
        self.generic_visit(node)

        # ignore if the function is not 'nni.*'
        if not (isinstance(node.func, ast.Attribute)
                and isinstance(node.func.value, ast.Name)
                and node.func.value.id == 'nni'):
            return node

        # ignore if it's not a search space function (e.g. `report_final_result`)
        func = node.func.attr
        if func not in _ss_funcs:
            return node

        self.last_line = node.lineno

        if func == 'mutable_layer':
            self.generate_mutable_layer_search_space(node.args)
            return node

        if node.keywords:
            # there is a `name` argument
            _require(len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"')
            _require(node.keywords[0].arg == 'name', 'Smart parameter\'s keyword argument is not "name"')
            name = _string_literal(node.keywords[0].value, 'Smart parameter\'s name must be string literal')
            specified_name = True
        else:
            # generate the missing name from the line of the last positional argument
            _require(node.args, 'Smart parameter has no arguments')
            name = '__line' + str(node.args[-1].lineno)
            specified_name = False
            node.keywords = []

        if func in ('choice', 'function_choice'):
            # the choices are the keys of the dict that code_generator builds from the user's arguments
            _require(len(node.args) == 1, 'Smart parameter has arguments other than dict')
            args = _dict_key_values(node.args[0], 'Smart parameter\'s choices must be a dict with literal keys')
        else:
            # arguments of other functions must be literal numbers (negative numbers included)
            message = 'Smart parameter\'s arguments must be number literals'
            args = [_literal_value(arg, message) for arg in node.args]
            _require(all(isinstance(arg, numbers.Real) for arg in args), message)

        key = self.module_name + '/' + name + '/' + func
        # store key in ast.Call
        node.keywords.append(ast.keyword(arg='key', value=_string_node(key)))

        if func == 'function_choice':
            func = 'choice'
        value = {'_type': func, '_value': args}

        if specified_name:
            # multiple functions with same name must have identical arguments
            old = self.search_space.get(key)
            _require(old is None or old == value, 'Different smart parameters have same name')
        else:
            # generated name must not duplicate
            _require(key not in self.search_space, 'Only one smart parameter is allowed in a line')

        self.search_space[key] = value
        return node

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

def generate(module_name, code):
    """Generate search space.

    module_name: name of the module (str); prefixes every search space key
    code: user code (str)

    Returns a ``(search_space, new_code)`` tuple: the serializable search space
    dict collected by ``SearchSpaceGenerator``, and the source code regenerated
    from the AST after the generator has rewritten the smart-parameter calls.

    Raises ``RuntimeError('Bad Python code')`` (chained to the parse error) if
    ``code`` cannot be parsed, and ``RuntimeError('<line>: <reason>')`` when a
    smart parameter is used incorrectly.
    """
    try:
        ast_tree = ast.parse(code)
    except Exception as exc:  # report every parse failure uniformly, but keep the cause
        raise RuntimeError('Bad Python code') from exc

    visitor = SearchSpaceGenerator(module_name)
    try:
        visitor.visit(ast_tree)
    except AssertionError as exc:
        # str(exc) also copes with an AssertionError raised without a message
        raise RuntimeError('%d: %s' % (visitor.last_line, exc)) from exc
    return visitor.search_space, astor.to_source(ast_tree)