search_space_generator.py 4.71 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4

import ast
5
import numbers
6

liuzhe-lz's avatar
liuzhe-lz committed
7
8
import astor

9
10
from .utils import ast_Num, ast_Str

11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# pylint: disable=unidiomatic-typecheck


# Names of the `nni.<name>(...)` calls that contribute entries to the search
# space. Any other `nni.*` call (e.g. `nni.report_final_result`) is ignored by
# SearchSpaceGenerator.visit_Call.
_ss_funcs = [
    # value-sampling functions
    'choice', 'randint',
    'uniform', 'quniform',
    'loguniform', 'qloguniform',
    'normal', 'qnormal',
    'lognormal', 'qlognormal',
    # function / layer selection
    'function_choice', 'mutable_layer',
]


Zejun Lin's avatar
Zejun Lin committed
31
class SearchSpaceGenerator(ast.NodeTransformer):
    """Generate search space from smart parameter APIs.

    Walks a module's AST and collects every ``nni.<func>(...)`` call whose
    ``func`` is listed in ``_ss_funcs`` into ``self.search_space``. Those calls
    are rewritten in place to carry their search-space key: a ``key=`` keyword
    argument, or the module-qualified block name for ``mutable_layer``.

    Validation failures raise ``AssertionError`` with a human-readable message.
    ``self.last_line`` holds the line of the call being processed so the
    caller can attach it to the message.
    """

    def __init__(self, module_name):
        self.module_name = module_name
        self.search_space = {}
        self.last_line = 0  # last parsed line, useful for error reporting

    @staticmethod
    def _check(condition, message):
        """Raise ``AssertionError(message)`` unless ``condition`` is true.

        This replaces ``assert`` statements so validation still runs under
        ``python -O``. The exception type stays ``AssertionError`` because
        the caller translates exactly that type into a line-numbered error.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _number_literal(arg):
        """Return the real-number value of the literal expression node ``arg``.

        Raises:
            AssertionError: if ``arg`` is not a literal, or is a literal that
                is not a real number.
        """
        try:
            value = ast.literal_eval(arg)
        except (ValueError, TypeError):
            # Non-literal expressions (names, calls, ...) are rejected below
            # with the smart-parameter diagnostic instead of a raw ValueError.
            value = None
        SearchSpaceGenerator._check(isinstance(value, numbers.Real),
                                    'Smart parameter\'s arguments must be number literals')
        return value

    def generate_mutable_layer_search_space(self, args):
        """Record one ``nni.mutable_layer`` call in ``self.search_space``.

        ``args`` is the positional-argument node list of the call. Only these
        positions are read (the layout is presumably fixed by code_generator;
        confirm there):
          args[0]  string literal: mutable block name. It is rewritten in
                   place to ``<module_name>/<block>``.
          args[1]  string literal: mutable layer name
          args[2]  dict node: its keys are the layer choices
          args[5]  dict node: its keys are the optional inputs
          args[6]  number literal, or a node with two number ``elts``:
                   the optional input size
        """
        mutable_block = args[0].s
        mutable_layer = args[1].s

        key = self.module_name + '/' + mutable_block
        # rewrite the block name in the AST so the emitted code uses the qualified key
        args[0].s = key
        if key not in self.search_space:
            self.search_space[key] = {'_type': 'mutable_layer', '_value': {}}

        size_node = args[6]
        if isinstance(size_node, ast_Num):
            optional_input_size = size_node.n
        else:
            optional_input_size = [size_node.elts[0].n, size_node.elts[1].n]

        self.search_space[key]['_value'][mutable_layer] = {
            'layer_choice': [k.s for k in args[2].keys],
            'optional_inputs': [k.s for k in args[5].keys],
            'optional_input_size': optional_input_size,
        }

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """Record and annotate ``nni.<search-space function>(...)`` calls.

        Children are visited first. Calls that are not ``nni.<name>`` with
        ``<name>`` in ``_ss_funcs`` are returned unchanged.

        Raises:
            AssertionError: if a smart parameter call is malformed.
        """
        self.generic_visit(node)

        # ignore if the function is not 'nni.*'
        if type(node.func) is not ast.Attribute:
            return node
        if type(node.func.value) is not ast.Name:
            return node
        if node.func.value.id != 'nni':
            return node

        # ignore if it's not a search space function (e.g. `report_final_result`)
        func = node.func.attr
        if func not in _ss_funcs:
            return node

        self.last_line = node.lineno

        if func == 'mutable_layer':
            self.generate_mutable_layer_search_space(node.args)
            return node

        if node.keywords:
            # there is a `name` argument
            self._check(len(node.keywords) == 1,
                        'Smart parameter has keyword argument other than "name"')
            self._check(node.keywords[0].arg == 'name',
                        'Smart parameter\'s keyword argument is not "name"')
            self._check(type(node.keywords[0].value) is ast_Str,
                        'Smart parameter\'s name must be string literal')
            name = node.keywords[0].value.s
            specified_name = True
        else:
            # Generate the missing name from the line of the last positional
            # argument. Fall back to the call's own line when there are no
            # arguments, instead of failing with IndexError.
            anchor = node.args[-1] if node.args else node
            name = '__line' + str(anchor.lineno)
            specified_name = False
            node.keywords = list()

        if func in ('choice', 'function_choice'):
            # The choices are the keys of the single dict argument, which is
            # generated by code_generator according to the args given by the user.
            self._check(len(node.args) == 1 and isinstance(node.args[0], ast.Dict),
                        'Smart parameter has arguments other than dict')
            # each key is either a number or a string literal
            args = [key.n if type(key) is ast_Num else key.s for key in node.args[0].keys]
        else:
            # arguments of other functions must be real-number literals
            args = [self._number_literal(arg) for arg in node.args]

        # the key keeps the original function name, even for 'function_choice'
        key = self.module_name + '/' + name + '/' + func
        # store key in ast.Call
        node.keywords.append(ast.keyword(arg='key', value=ast_Str(s=key)))

        if func == 'function_choice':
            func = 'choice'
        value = {'_type': func, '_value': args}

        if specified_name:
            # multiple functions with same name must have identical arguments
            old = self.search_space.get(key)
            self._check(old is None or old == value, 'Different smart parameters have same name')
        else:
            # generated name must not duplicate
            self._check(key not in self.search_space, 'Only one smart parameter is allowed in a line')

        self.search_space[key] = value
        return node

118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

def generate(module_name, code):
    """Generate the search space for a module and rewrite its smart-parameter calls.

    module_name: name of the module (str); used as the prefix of every
        search-space key.
    code: user code (str)

    Return a ``(search_space, new_code)`` tuple. ``search_space`` is the
    serializable search space dict collected by ``SearchSpaceGenerator``.
    ``new_code`` is the source regenerated from the transformed AST.

    Raise RuntimeError if ``code`` cannot be parsed, or if a smart parameter
    call fails validation. In the latter case the message is prefixed with
    the offending line number.
    """
    try:
        ast_tree = ast.parse(code)
    except Exception as err:
        # ast.parse may raise SyntaxError, ValueError (e.g. null bytes) or
        # TypeError (non-str input); all are reported uniformly, with the cause kept.
        raise RuntimeError('Bad Python code') from err

    visitor = SearchSpaceGenerator(module_name)
    try:
        visitor.visit(ast_tree)
    except AssertionError as exc:
        # guard against an AssertionError raised without a message
        detail = exc.args[0] if exc.args else 'invalid smart parameter'
        raise RuntimeError('%d: %s' % (visitor.last_line, detail)) from exc
    return visitor.search_space, astor.to_source(ast_tree)