# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import ast
import numbers
import sys

import astor

# pylint: disable=unidiomatic-typecheck

# list of functions related to search space generating
_ss_funcs = [
    'choice',
    'randint',
    'uniform',
    'quniform',
    'loguniform',
    'qloguniform',
    'normal',
    'qnormal',
    'lognormal',
    'qlognormal',
    'function_choice',
    'mutable_layer'
]


def _literal_value(node):
    """Return the Python value held by a string/number literal AST node.

    Python 3.8+ parses every literal into `ast.Constant`; older interpreters
    produce the legacy `Str` / `Num` nodes, which expose `.s` / `.n` instead.
    The legacy classes are never referenced by name here because they are
    deprecated and absent from recent Python versions.
    """
    if isinstance(node, ast.Constant):
        return node.value
    return node.n if hasattr(node, 'n') else node.s


def _set_literal_value(node, value):
    """Overwrite the value held by a literal AST node in place."""
    if isinstance(node, ast.Constant):
        node.value = value
    else:
        node.s = value


def _is_str_literal(node):
    """Tell whether `node` is a string literal."""
    if isinstance(node, ast.Constant):
        return isinstance(node.value, str)
    return type(node).__name__ == 'Str'


def _is_num_literal(node):
    """Tell whether `node` is a numeric (non-bool) literal."""
    if isinstance(node, ast.Constant):
        value = node.value
        return isinstance(value, (int, float, complex)) and not isinstance(value, bool)
    return type(node).__name__ == 'Num'


def _make_str_node(value):
    """Create a string literal AST node appropriate for the running interpreter."""
    if sys.version_info >= (3, 8):
        return ast.Constant(value=value)
    return ast.Str(s=value)  # pylint: disable=no-member


class SearchSpaceGenerator(ast.NodeTransformer):
    """Generate search space from smart parameter APIs.

    Visiting a tree collects every `nni.<search space function>(...)` call into
    `self.search_space` and rewrites the call so that it carries the generated key.
    """

    def __init__(self, module_name):
        self.module_name = module_name
        self.search_space = {}
        self.last_line = 0  # last parsed line, useful for error reporting

    def generate_mutable_layer_search_space(self, args):
        """Record one `nni.mutable_layer(...)` call in the search space.

        args: positional argument nodes of the call. The code reads
            args[0] (mutable block name), args[1] (mutable layer name),
            args[2] and args[5] (dict nodes whose keys are the layer choices and
            the optional inputs) and args[6] (a number literal, or a sequence
            node whose first two elements are number literals).

        args[0] is rewritten in place to the module-qualified key.
        """
        mutable_block = _literal_value(args[0])
        mutable_layer = _literal_value(args[1])
        key = self.module_name + '/' + mutable_block
        _set_literal_value(args[0], key)
        if key not in self.search_space:
            self.search_space[key] = {'_type': 'mutable_layer', '_value': {}}

        size_node = args[6]
        if _is_num_literal(size_node):
            input_size = _literal_value(size_node)
        else:
            input_size = [_literal_value(size_node.elts[0]), _literal_value(size_node.elts[1])]

        self.search_space[key]['_value'][mutable_layer] = {
            'layer_choice': [_literal_value(k) for k in args[2].keys],
            'optional_inputs': [_literal_value(k) for k in args[5].keys],
            'optional_input_size': input_size
        }

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """Register `nni.<func>(...)` search space calls and tag them with a `key` keyword.

        Calls that are not `nni.<name in _ss_funcs>` are returned untouched.
        Raises AssertionError (with a user-facing message) on malformed usage.
        """
        self.generic_visit(node)

        # ignore if the function is not 'nni.*'
        if type(node.func) is not ast.Attribute:
            return node
        if type(node.func.value) is not ast.Name:
            return node
        if node.func.value.id != 'nni':
            return node

        # ignore if its not a search space function (e.g. `report_final_result`)
        func = node.func.attr
        if func not in _ss_funcs:
            return node

        self.last_line = node.lineno

        if func == 'mutable_layer':
            self.generate_mutable_layer_search_space(node.args)
            return node

        if node.keywords:
            # there is a `name` argument
            assert len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"'
            assert node.keywords[0].arg == 'name', 'Smart paramater\'s keyword argument is not "name"'
            # an identity check against the legacy `Str` class never matches the
            # `Constant` nodes produced by Python 3.8+, so inspect the value instead
            assert _is_str_literal(node.keywords[0].value), 'Smart parameter\'s name must be string literal'
            name = _literal_value(node.keywords[0].value)
            specified_name = True
        else:
            # generate the missing name automatically
            # (the name is derived from the last positional argument, so one must exist)
            assert node.args, 'Smart parameter has no arguments'
            name = '__line' + str(node.args[-1].lineno)
            specified_name = False
            node.keywords = list()

        if func in ('choice', 'function_choice'):
            # we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
            assert len(node.args) == 1, 'Smart parameter has arguments other than dict'
            # keys are number or string literals; take their values as they are
            args = [_literal_value(key) for key in node.args[0].keys]
        else:
            # arguments of other functions must be literal number
            try:
                # literal_eval accepts AST nodes directly, no need to regenerate source
                args = [ast.literal_eval(arg) for arg in node.args]
            except (ValueError, TypeError) as exc:
                # non-literal arguments must be reported like any other misuse
                raise AssertionError('Smart parameter\'s arguments must be number literals') from exc
            assert all(isinstance(arg, numbers.Real) for arg in args), \
                'Smart parameter\'s arguments must be number literals'

        key = self.module_name + '/' + name + '/' + func
        # store key in ast.Call
        node.keywords.append(ast.keyword(arg='key', value=_make_str_node(key)))

        if func == 'function_choice':
            func = 'choice'
        value = {'_type': func, '_value': args}

        if specified_name:
            # multiple functions with same name must have identical arguments
            old = self.search_space.get(key)
            assert old is None or old == value, 'Different smart parameters have same name'
        else:
            # generated name must not duplicate
            assert key not in self.search_space, 'Only one smart parameter is allowed in a line'

        self.search_space[key] = value
        return node


def generate(module_name, code):
    """Generate search space.

    Return a tuple `(search_space, annotated_code)`: the serializable search
    space object and the user code regenerated from the rewritten AST.

    module_name: name of the module (str)
    code: user code (str)

    Raises RuntimeError if the code cannot be parsed, or (prefixed with the
    line number of the offending call) if a smart parameter is misused.
    """
    try:
        ast_tree = ast.parse(code)
    except Exception as exc:
        raise RuntimeError('Bad Python code') from exc

    visitor = SearchSpaceGenerator(module_name)
    try:
        visitor.visit(ast_tree)
    except AssertionError as exc:
        raise RuntimeError('%d: %s' % (visitor.last_line, exc.args[0])) from exc
    return visitor.search_space, astor.to_source(ast_tree)