"...composable_kernel_rocm.git" did not exist on "0c7b35c4f893b97aa70088c194b57a7a41790fff"
Commit 7108466c authored by Zejun Lin's avatar Zejun Lin Committed by QuanluZhang
Browse files

fix annotation key-error (#806)

* fix annotation, resolve annotation's key err bug, refactor the design
parent f10c3311
...@@ -82,52 +82,40 @@ if env_args.platform is None: ...@@ -82,52 +82,40 @@ if env_args.platform is None:
else: else:
def choice(options, name=None): def choice(options, name=None, key=None):
return options[_get_param('choice', name)] return options[_get_param(key)]
def randint(upper, name=None): def randint(upper, name=None, key=None):
return _get_param('randint', name) return _get_param(key)
def uniform(low, high, name=None): def uniform(low, high, name=None, key=None):
return _get_param('uniform', name) return _get_param(key)
def quniform(low, high, q, name=None): def quniform(low, high, q, name=None, key=None):
return _get_param('quniform', name) return _get_param(key)
def loguniform(low, high, name=None): def loguniform(low, high, name=None, key=None):
return _get_param('loguniform', name) return _get_param(key)
def qloguniform(low, high, q, name=None): def qloguniform(low, high, q, name=None, key=None):
return _get_param('qloguniform', name) return _get_param(key)
def normal(mu, sigma, name=None): def normal(mu, sigma, name=None, key=None):
return _get_param('normal', name) return _get_param(key)
def qnormal(mu, sigma, q, name=None): def qnormal(mu, sigma, q, name=None, key=None):
return _get_param('qnormal', name) return _get_param(key)
def lognormal(mu, sigma, name=None): def lognormal(mu, sigma, name=None, key=None):
return _get_param('lognormal', name) return _get_param(key)
def qlognormal(mu, sigma, q, name=None): def qlognormal(mu, sigma, q, name=None, key=None):
return _get_param('qlognormal', name) return _get_param(key)
def function_choice(funcs, name=None): def function_choice(funcs, name=None, key=None):
return funcs[_get_param('function_choice', name)]() return funcs[_get_param(key)]()
def _get_param(func, name): def _get_param(key):
# frames:
# layer 0: this function
# layer 1: the API function (caller of this function)
# layer 2: caller of the API function
frame = inspect.stack(0)[2]
filename = frame.filename
lineno = frame.lineno # NOTE: this is the lineno of caller's last argument
del frame # see official doc
module = inspect.getmodulename(filename)
if name is None:
name = '__line{:d}'.format(lineno)
key = '{}/{}/{}'.format(module, name, func)
if trial._params is None: if trial._params is None:
trial.get_next_parameter() trial.get_next_parameter()
return trial.get_current_parameter(key) return trial.get_current_parameter(key)
...@@ -29,8 +29,6 @@ import nni.trial ...@@ -29,8 +29,6 @@ import nni.trial
from unittest import TestCase, main from unittest import TestCase, main
lineno1 = 61
lineno2 = 75
class SmartParamTestCase(TestCase): class SmartParamTestCase(TestCase):
def setUp(self): def setUp(self):
...@@ -39,43 +37,30 @@ class SmartParamTestCase(TestCase): ...@@ -39,43 +37,30 @@ class SmartParamTestCase(TestCase):
'test_smartparam/choice2/choice': '3*2+1', 'test_smartparam/choice2/choice': '3*2+1',
'test_smartparam/choice3/choice': '[1, 2]', 'test_smartparam/choice3/choice': '[1, 2]',
'test_smartparam/choice4/choice': '{"a", 2}', 'test_smartparam/choice4/choice': '{"a", 2}',
'test_smartparam/__line{:d}/uniform'.format(lineno1): '5',
'test_smartparam/func/function_choice': 'bar', 'test_smartparam/func/function_choice': 'bar',
'test_smartparam/lambda_func/function_choice': "lambda: 2*3", 'test_smartparam/lambda_func/function_choice': "lambda: 2*3"
'test_smartparam/__line{:d}/function_choice'.format(lineno2): 'max(1, 2, 3)'
} }
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params } nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }
def test_specified_name(self): def test_specified_name(self):
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1') val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1', key='test_smartparam/choice1/choice')
self.assertEqual(val, 'a') self.assertEqual(val, 'a')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2') val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2', key='test_smartparam/choice2/choice')
self.assertEqual(val, 7) self.assertEqual(val, 7)
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3') val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3', key='test_smartparam/choice3/choice')
self.assertEqual(val, [1, 2]) self.assertEqual(val, [1, 2])
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4') val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4', key='test_smartparam/choice4/choice')
self.assertEqual(val, {"a", 2}) self.assertEqual(val, {"a", 2})
def test_default_name(self): def test_func(self):
val = nni.uniform(1, 10) # NOTE: assign this line number to lineno1 val = nni.function_choice({'foo': foo, 'bar': bar}, name='func', key='test_smartparam/func/function_choice')
self.assertEqual(val, '5')
def test_specified_name_func(self):
val = nni.function_choice({'foo': foo, 'bar': bar}, name = 'func')
self.assertEqual(val, 'bar') self.assertEqual(val, 'bar')
def test_lambda_func(self): def test_lambda_func(self):
val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func') val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func', key='test_smartparam/lambda_func/function_choice')
self.assertEqual(val, 6) self.assertEqual(val, 6)
def test_default_name_func(self):
val = nni.function_choice({
'max(1, 2, 3)': lambda: max(1, 2, 3),
'min(1, 2)': lambda: min(1, 2) # NOTE: assign this line number to lineno2
})
self.assertEqual(val, 3)
def foo(): def foo():
return 'foo' return 'foo'
......
...@@ -59,12 +59,15 @@ def generate_search_space(code_dir): ...@@ -59,12 +59,15 @@ def generate_search_space(code_dir):
def _generate_file_search_space(path, module): def _generate_file_search_space(path, module):
with open(path) as src: with open(path) as src:
try: try:
return search_space_generator.generate(module, src.read()) search_space, code = search_space_generator.generate(module, src.read())
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
if exc.args: if exc.args:
raise RuntimeError(path + ' ' + '\n'.join(exc.args)) raise RuntimeError(path + ' ' + '\n'.join(exc.args))
else: else:
raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc)) raise RuntimeError('Failed to generate search space for %s: %r' % (path, exc))
with open(path, 'w') as dst:
dst.write(code)
return search_space
def expand_annotations(src_dir, dst_dir): def expand_annotations(src_dir, dst_dir):
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
import ast import ast
import astor
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
...@@ -40,7 +41,7 @@ _ss_funcs = [ ...@@ -40,7 +41,7 @@ _ss_funcs = [
] ]
class SearchSpaceGenerator(ast.NodeVisitor): class SearchSpaceGenerator(ast.NodeTransformer):
"""Generate search space from smart parater APIs""" """Generate search space from smart parater APIs"""
def __init__(self, module_name): def __init__(self, module_name):
...@@ -53,16 +54,16 @@ class SearchSpaceGenerator(ast.NodeVisitor): ...@@ -53,16 +54,16 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# ignore if the function is not 'nni.*' # ignore if the function is not 'nni.*'
if type(node.func) is not ast.Attribute: if type(node.func) is not ast.Attribute:
return return node
if type(node.func.value) is not ast.Name: if type(node.func.value) is not ast.Name:
return return node
if node.func.value.id != 'nni': if node.func.value.id != 'nni':
return return node
# ignore if its not a search space function (e.g. `report_final_result`) # ignore if its not a search space function (e.g. `report_final_result`)
func = node.func.attr func = node.func.attr
if func not in _ss_funcs: if func not in _ss_funcs:
return return node
self.last_line = node.lineno self.last_line = node.lineno
...@@ -77,6 +78,7 @@ class SearchSpaceGenerator(ast.NodeVisitor): ...@@ -77,6 +78,7 @@ class SearchSpaceGenerator(ast.NodeVisitor):
# generate the missing name automatically # generate the missing name automatically
name = '__line' + str(str(node.args[-1].lineno)) name = '__line' + str(str(node.args[-1].lineno))
specified_name = False specified_name = False
node.keywords = list()
if func in ('choice', 'function_choice'): if func in ('choice', 'function_choice'):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user # we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
...@@ -89,6 +91,9 @@ class SearchSpaceGenerator(ast.NodeVisitor): ...@@ -89,6 +91,9 @@ class SearchSpaceGenerator(ast.NodeVisitor):
args = [arg.n for arg in node.args] args = [arg.n for arg in node.args]
key = self.module_name + '/' + name + '/' + func key = self.module_name + '/' + name + '/' + func
# store key in ast.Call
node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key)))
if func == 'function_choice': if func == 'function_choice':
func = 'choice' func = 'choice'
value = {'_type': func, '_value': args} value = {'_type': func, '_value': args}
...@@ -103,6 +108,8 @@ class SearchSpaceGenerator(ast.NodeVisitor): ...@@ -103,6 +108,8 @@ class SearchSpaceGenerator(ast.NodeVisitor):
self.search_space[key] = value self.search_space[key] = value
return node
def generate(module_name, code): def generate(module_name, code):
"""Generate search space. """Generate search space.
...@@ -120,4 +127,4 @@ def generate(module_name, code): ...@@ -120,4 +127,4 @@ def generate(module_name, code):
visitor.visit(ast_tree) visitor.visit(ast_tree)
except AssertionError as exc: except AssertionError as exc:
raise RuntimeError('%d: %s' % (visitor.last_line, exc.args[0])) raise RuntimeError('%d: %s' % (visitor.last_line, exc.args[0]))
return visitor.search_space return visitor.search_space, astor.to_source(ast_tree)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment