Commit 7108466c authored by Zejun Lin's avatar Zejun Lin Committed by QuanluZhang
Browse files

fix annotation key-error (#806)

* fix annotation, resolve annotation's key err bug, refactor the design
parent f10c3311
......@@ -82,52 +82,40 @@ if env_args.platform is None:
else:
def choice(options, name=None):
return options[_get_param('choice', name)]
def choice(options, name=None, key=None):
return options[_get_param(key)]
def randint(upper, name=None):
return _get_param('randint', name)
def randint(upper, name=None, key=None):
return _get_param(key)
def uniform(low, high, name=None):
return _get_param('uniform', name)
def uniform(low, high, name=None, key=None):
return _get_param(key)
def quniform(low, high, q, name=None):
return _get_param('quniform', name)
def quniform(low, high, q, name=None, key=None):
return _get_param(key)
def loguniform(low, high, name=None):
return _get_param('loguniform', name)
def loguniform(low, high, name=None, key=None):
return _get_param(key)
def qloguniform(low, high, q, name=None):
return _get_param('qloguniform', name)
def qloguniform(low, high, q, name=None, key=None):
return _get_param(key)
def normal(mu, sigma, name=None):
return _get_param('normal', name)
def normal(mu, sigma, name=None, key=None):
return _get_param(key)
def qnormal(mu, sigma, q, name=None):
return _get_param('qnormal', name)
def qnormal(mu, sigma, q, name=None, key=None):
return _get_param(key)
def lognormal(mu, sigma, name=None):
return _get_param('lognormal', name)
def lognormal(mu, sigma, name=None, key=None):
return _get_param(key)
def qlognormal(mu, sigma, q, name=None):
return _get_param('qlognormal', name)
def function_choice(funcs, name=None):
return funcs[_get_param('function_choice', name)]()
def _get_param(func, name):
# frames:
# layer 0: this function
# layer 1: the API function (caller of this function)
# layer 2: caller of the API function
frame = inspect.stack(0)[2]
filename = frame.filename
lineno = frame.lineno # NOTE: this is the lineno of caller's last argument
del frame # see official doc
module = inspect.getmodulename(filename)
if name is None:
name = '__line{:d}'.format(lineno)
key = '{}/{}/{}'.format(module, name, func)
def qlognormal(mu, sigma, q, name=None, key=None):
return _get_param(key)
def function_choice(funcs, name=None, key=None):
return funcs[_get_param(key)]()
def _get_param(key):
if trial._params is None:
trial.get_next_parameter()
return trial.get_current_parameter(key)
......@@ -29,8 +29,6 @@ import nni.trial
from unittest import TestCase, main
lineno1 = 61
lineno2 = 75
class SmartParamTestCase(TestCase):
def setUp(self):
......@@ -39,43 +37,30 @@ class SmartParamTestCase(TestCase):
'test_smartparam/choice2/choice': '3*2+1',
'test_smartparam/choice3/choice': '[1, 2]',
'test_smartparam/choice4/choice': '{"a", 2}',
'test_smartparam/__line{:d}/uniform'.format(lineno1): '5',
'test_smartparam/func/function_choice': 'bar',
'test_smartparam/lambda_func/function_choice': "lambda: 2*3",
'test_smartparam/__line{:d}/function_choice'.format(lineno2): 'max(1, 2, 3)'
'test_smartparam/lambda_func/function_choice': "lambda: 2*3"
}
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }
def test_specified_name(self):
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1', key='test_smartparam/choice1/choice')
self.assertEqual(val, 'a')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2', key='test_smartparam/choice2/choice')
self.assertEqual(val, 7)
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3', key='test_smartparam/choice3/choice')
self.assertEqual(val, [1, 2])
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4', key='test_smartparam/choice4/choice')
self.assertEqual(val, {"a", 2})
def test_default_name(self):
val = nni.uniform(1, 10) # NOTE: assign this line number to lineno1
self.assertEqual(val, '5')
def test_specified_name_func(self):
val = nni.function_choice({'foo': foo, 'bar': bar}, name = 'func')
def test_func(self):
val = nni.function_choice({'foo': foo, 'bar': bar}, name='func', key='test_smartparam/func/function_choice')
self.assertEqual(val, 'bar')
def test_lambda_func(self):
val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func')
val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func', key='test_smartparam/lambda_func/function_choice')
self.assertEqual(val, 6)
def test_default_name_func(self):
val = nni.function_choice({
'max(1, 2, 3)': lambda: max(1, 2, 3),
'min(1, 2)': lambda: min(1, 2) # NOTE: assign this line number to lineno2
})
self.assertEqual(val, 3)
def foo():
return 'foo'
......
......@@ -59,12 +59,15 @@ def generate_search_space(code_dir):
def _generate_file_search_space(path, module):
with open(path) as src:
try:
return search_space_generator.generate(module, src.read())
search_space, code = search_space_generator.generate(module, src.read())
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(path + ' ' + '\n'.join(exc.args))
else:
raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc))
with open(path, 'w') as dst:
dst.write(code)
return search_space
def expand_annotations(src_dir, dst_dir):
......
......@@ -20,6 +20,7 @@
import ast
import astor
# pylint: disable=unidiomatic-typecheck
......@@ -40,7 +41,7 @@ _ss_funcs = [
]
class SearchSpaceGenerator(ast.NodeVisitor):
class SearchSpaceGenerator(ast.NodeTransformer):
"""Generate search space from smart parater APIs"""
def __init__(self, module_name):
......@@ -53,16 +54,16 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# ignore if the function is not 'nni.*'
if type(node.func) is not ast.Attribute:
return
return node
if type(node.func.value) is not ast.Name:
return
return node
if node.func.value.id != 'nni':
return
return node
# ignore if its not a search space function (e.g. `report_final_result`)
func = node.func.attr
if func not in _ss_funcs:
return
return node
self.last_line = node.lineno
......@@ -77,6 +78,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# generate the missing name automatically
name = '__line' + str(str(node.args[-1].lineno))
specified_name = False
node.keywords = list()
if func in ('choice', 'function_choice'):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
......@@ -89,6 +91,9 @@ class SearchSpaceGenerator(ast.NodeVisitor):
args = [arg.n for arg in node.args]
key = self.module_name + '/' + name + '/' + func
# store key in ast.Call
node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key)))
if func == 'function_choice':
func = 'choice'
value = {'_type': func, '_value': args}
......@@ -103,6 +108,8 @@ class SearchSpaceGenerator(ast.NodeVisitor):
self.search_space[key] = value
return node
def generate(module_name, code):
"""Generate search space.
......@@ -120,4 +127,4 @@ def generate(module_name, code):
visitor.visit(ast_tree)
except AssertionError as exc:
raise RuntimeError('%d: %s' % (visitor.last_line, exc.args[0]))
return visitor.search_space
return visitor.search_space, astor.to_source(ast_tree)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment