Commit 5461fe77 authored by Zejun Lin's avatar Zejun Lin Committed by QuanluZhang
Browse files

Update annotation to support displaying real choice in searchspace (#471)

* fix-annotation

* fix annotation for error printing

* update annotation

* update annotation

* update annotation

* update annotation

* update unittest for annotation

* update unit test

* update annotation

* update annotation

* update annotation

* update ut

* update ut
parent ff834cea
......@@ -32,6 +32,9 @@ _last_metric = None
def get_next_parameter():
return _params
def get_sequence_id():
return 0
def send_metric(string):
global _last_metric
_last_metric = string
......
......@@ -82,7 +82,7 @@ if env_args.platform is None:
else:
def choice(*options, name=None):
def choice(options, name=None):
return options[_get_param('choice', name)]
def randint(upper, name=None):
......@@ -112,7 +112,7 @@ else:
def qlognormal(mu, sigma, q, name=None):
return _get_param('qlognormal', name)
def function_choice(*funcs, name=None):
def function_choice(funcs, name=None):
return funcs[_get_param('function_choice', name)]()
def _get_param(func, name):
......
......@@ -18,6 +18,9 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
os.environ['NNI_PLATFORM'] = 'unittest'
import nni
import nni.platform.test as test_platform
......@@ -26,37 +29,51 @@ import nni.trial
from unittest import TestCase, main
lineno1 = 48
lineno2 = 58
lineno1 = 61
lineno2 = 75
class SmartParamTestCase(TestCase):
def setUp(self):
params = {
'test_smartparam/choice1/choice': 2,
'test_smartparam/choice1/choice': 'a',
'test_smartparam/choice2/choice': '3*2+1',
'test_smartparam/choice3/choice': '[1, 2]',
'test_smartparam/choice4/choice': '{"a", 2}',
'test_smartparam/__line{:d}/uniform'.format(lineno1): '5',
'test_smartparam/func/function_choice': 1,
'test_smartparam/__line{:d}/function_choice'.format(lineno2): 0
'test_smartparam/func/function_choice': 'bar',
'test_smartparam/lambda_func/function_choice': "lambda: 2*3",
'test_smartparam/__line{:d}/function_choice'.format(lineno2): 'max(1, 2, 3)'
}
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }
def test_specified_name(self):
val = nni.choice('a', 'b', 'c', name = 'choice1')
self.assertEqual(val, 'c')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice1')
self.assertEqual(val, 'a')
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice2')
self.assertEqual(val, 7)
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice3')
self.assertEqual(val, [1, 2])
val = nni.choice({'a': 'a', '3*2+1': 3*2+1, '[1, 2]': [1, 2], '{"a", 2}': {"a", 2}}, name = 'choice4')
self.assertEqual(val, {"a", 2})
def test_default_name(self):
val = nni.uniform(1, 10) # NOTE: assign this line number to lineno1
self.assertEqual(val, '5')
def test_specified_name_func(self):
val = nni.function_choice(foo, bar, name = 'func')
val = nni.function_choice({'foo': foo, 'bar': bar}, name = 'func')
self.assertEqual(val, 'bar')
def test_lambda_func(self):
val = nni.function_choice({"lambda: 2*3": lambda: 2*3, "lambda: 3*4": lambda: 3*4}, name = 'lambda_func')
self.assertEqual(val, 6)
def test_default_name_func(self):
val = nni.function_choice(
lambda: max(1, 2, 3),
lambda: 2 * 2 # NOTE: assign this line number to lineno2
)
val = nni.function_choice({
'max(1, 2, 3)': lambda: max(1, 2, 3),
'min(1, 2)': lambda: min(1, 2) # NOTE: assign this line number to lineno2
})
self.assertEqual(val, 3)
......
......@@ -110,6 +110,6 @@ def _expand_file_annotations(src_path, dst_path):
except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(src_path + ' ' + '\n'.join(exc.args))
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
......@@ -76,6 +76,8 @@ def parse_nni_variable(code):
name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice':
convert_args_to_dict(arg)
return name, arg
......@@ -87,7 +89,7 @@ def parse_nni_function(code):
"""
name, call = parse_annotation_function(code, 'function_choice')
funcs = [ast.dump(func, False) for func in call.args]
call.args = [make_lambda(arg) for arg in call.args]
convert_args_to_dict(call, with_lambda=True)
name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str)
......@@ -95,11 +97,31 @@ def parse_nni_function(code):
return call, funcs
def convert_args_to_dict(call, with_lambda=False):
"""Convert all args to a dict such that every key and value in the dict is the same as the value of the arg.
Return the AST Call node with only one arg that is the dictionary
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
del call.args[:]
call.args.append(ast.Dict(keys=keys, values=values))
return call
def make_lambda(call):
"""Wrap an AST Call node to lambda expression node.
call: ast.Call node
"""
assert type(call) is ast.Call, 'Argument of nni.function_choice is not function call'
empty_args = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[])
return ast.Lambda(args=empty_args, body=call)
......
......@@ -75,14 +75,14 @@ class SearchSpaceGenerator(ast.NodeVisitor):
specified_name = True
else:
# generate the missing name automatically
assert len(node.args) > 0, 'Smart parameter expression has no argument'
name = '__line' + str(node.args[-1].lineno)
name = '__line' + str(str(node.args[-1].lineno))
specified_name = False
if func in ('choice', 'function_choice'):
# arguments of `choice` may contain complex expression,
# so use indices instead of arguments
args = list(range(len(node.args)))
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
assert len(node.args) == 1, 'Smart parameter has arguments other than dict'
# check if it is a number or a string and get its value accordingly
args = [key.n if type(key) is ast.Num else key.s for key in node.args[0].keys]
else:
# arguments of other functions must be literal number
assert all(type(arg) is ast.Num for arg in node.args), 'Smart parameter\'s arguments must be number literals'
......
import nni
def max_pool(k):
pass
h_conv1 = 1
conv_size = nni.choice(2, 3, 5, 7, name='conv_size')
h_pool1 = nni.function_choice(lambda : max_pool(h_conv1), lambda : avg_pool
(h_conv2, h_conv3), name='max_pool')
conv_size = nni.choice({2: 2, 3: 3, 5: 5, 7: 7}, name='conv_size')
abc = nni.choice({'2': '2', 3: 3, '(5 * 6)': 5 * 6, "{(1): 2, '3': 4}": {(1
): 2, '3': 4}, '[1, 2, 3]': [1, 2, 3]}, name='abc')
h_pool1 = nni.function_choice({'max_pool(h_conv1)': lambda : max_pool(
h_conv1), 'avg_pool(h_conv2, h_conv3)': lambda : avg_pool(h_conv2,
h_conv3)}, name='max_pool')
h_pool2 = nni.function_choice({'max_poo(h_conv1)': lambda : max_poo(h_conv1
), '(2 * 3 + 4)': lambda : 2 * 3 + 4, '(lambda x: 1 + x)': lambda : lambda
x: 1 + x}, name='max_poo')
test_acc = 1
nni.report_intermediate_result(test_acc)
test_acc = 2
......
import nni
def max_pool(k):
pass
h_conv1 = 1
conv_size = nni.choice(2, 3, 5, 7, name='conv_size')
h_pool1 = nni.function_choice(lambda : max_pool(h_conv1),
lambda : h_conv1,
lambda : avg_pool
(h_conv2, h_conv3)
nni.choice({'foo': foo, 'bar': bar})(1)
conv_size = nni.choice({2: 2, 3: 3, 5: 5, 7: 7}, name='conv_size')
abc = nni.choice({'2': '2', 3: 3, '(5 * 6)': 5 * 6, 7: 7}, name='abc')
h_pool1 = nni.function_choice({'max_pool': lambda : max_pool(h_conv1),
'h_conv1': lambda : h_conv1,
'avg_pool': lambda : avg_pool(h_conv2, h_conv3)}
)
h_pool1 = nni.function_choice({'max_pool(h_conv1)': lambda : max_pool(
h_conv1), 'avg_pool(h_conv2, h_conv3)': lambda : avg_pool(h_conv2,
h_conv3)}, name='max_pool')
h_pool2 = nni.function_choice({'max_poo(h_conv1)': lambda : max_poo(h_conv1
), '(2 * 3 + 4)': lambda : 2 * 3 + 4, '(lambda x: 1 + x)': lambda : lambda
x: 1 + x}, name='max_poo')
tmp = nni.qlognormal(1.2, 3, 4.5)
test_acc = 1
nni.report_intermediate_result(test_acc)
test_acc = 2
nni.report_final_result(test_acc)
nni.choice(foo, bar)(1)
......@@ -3,15 +3,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import nni
import logging
import math
import tempfile
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
logger = logging.getLogger('mnist')
FLAGS = None
......@@ -23,8 +20,10 @@ class MnistNetwork(object):
y_dim=10):
self.channel_1_num = channel_1_num
self.channel_2_num = channel_2_num
self.conv_size = nni.choice(2, 3, 5, 7, name='self.conv_size')
self.hidden_size = nni.choice(124, 512, 1024, name='self.hidden_size')
self.conv_size = nni.choice({2: 2, 3: 3, 5: 5, 7: 7}, name=
'self.conv_size')
self.hidden_size = nni.choice({124: 124, 512: 512, 1024: 1024},
name='self.hidden_size')
self.pool_size = pool_size
self.learning_rate = nni.randint(2, 3, 5, name='self.learning_rate')
self.x_dim = x_dim
......@@ -47,14 +46,20 @@ class MnistNetwork(object):
W_conv1 = weight_variable([self.conv_size, self.conv_size, 1,
self.channel_1_num])
b_conv1 = bias_variable([self.channel_1_num])
h_conv1 = nni.function_choice(lambda : tf.nn.relu(conv2d(
x_image, W_conv1) + b_conv1), lambda : tf.nn.sigmoid(conv2d
(x_image, W_conv1) + b_conv1), lambda : tf.nn.tanh(conv2d(
x_image, W_conv1) + b_conv1), name='tf.nn.relu')
h_conv1 = nni.function_choice({
'tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)': lambda :
tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1),
'tf.nn.sigmoid(conv2d(x_image, W_conv1) + b_conv1)': lambda :
tf.nn.sigmoid(conv2d(x_image, W_conv1) + b_conv1),
'tf.nn.tanh(conv2d(x_image, W_conv1) + b_conv1)': lambda :
tf.nn.tanh(conv2d(x_image, W_conv1) + b_conv1)}, name=
'tf.nn.relu')
with tf.name_scope('pool1'):
h_pool1 = nni.function_choice(lambda : max_pool(h_conv1, self.
pool_size), lambda : avg_pool(h_conv1, self.pool_size),
name='max_pool')
h_pool1 = nni.function_choice({
'max_pool(h_conv1, self.pool_size)': lambda : max_pool(
h_conv1, self.pool_size),
'avg_pool(h_conv1, self.pool_size)': lambda : avg_pool(
h_conv1, self.pool_size)}, name='max_pool')
with tf.name_scope('conv2'):
W_conv2 = weight_variable([self.conv_size, self.conv_size, self
.channel_1_num, self.channel_2_num])
......@@ -135,9 +140,10 @@ def main():
sess.run(tf.global_variables_initializer())
batch_num = 200
for i in range(batch_num):
batch_size = nni.choice(50, 250, 500, name='batch_size')
batch_size = nni.choice({50: 50, 250: 250, 500: 500}, name=
'batch_size')
batch = mnist.train.next_batch(batch_size)
dropout_rate = nni.choice(1, 5, name='dropout_rate')
dropout_rate = nni.choice({1: 1, 5: 5}, name='dropout_rate')
mnist_network.train_step.run(feed_dict={mnist_network.x: batch[
0], mnist_network.y: batch[1], mnist_network.keep_prob:
dropout_rate})
......
{
"handwrite/__line6/choice": {
"_type": "choice",
"_value": [
"foo",
"bar"
]
},
"handwrite/conv_size/choice": {
"_type": "choice",
"_value": [ 0, 1, 2, 3 ]
"_value": [
2,
3,
5,
7
]
},
"handwrite/__line5/function_choice": {
"handwrite/abc/choice": {
"_type": "choice",
"_value": [ 0, 1, 2 ]
"_value": [
"2",
3,
"(5 * 6)",
7
]
},
"handwrite/__line8/qlognormal": {
"_type": "qlognormal",
"_value": [ 1.2, 3, 4.5 ]
"handwrite/__line9/function_choice": {
"_type": "choice",
"_value": [
"max_pool",
"h_conv1",
"avg_pool"
]
},
"handwrite/max_pool/function_choice": {
"_type": "choice",
"_value": [
"max_pool(h_conv1)",
"avg_pool(h_conv2, h_conv3)"
]
},
"handwrite/__line13/choice": {
"handwrite/max_poo/function_choice": {
"_type": "choice",
"_value": [ 0, 1 ]
"_value": [
"max_poo(h_conv1)",
"(2 * 3 + 4)",
"(lambda x: 1 + x)"
]
},
"handwrite/__line19/qlognormal": {
"_type": "qlognormal",
"_value": [
1.2,
3,
4.5
]
},
"mnist/self.conv_size/choice": {
"_type": "choice",
"_value": [ 0, 1, 2, 3 ]
"_value": [
2,
3,
5,
7
]
},
"mnist/self.hidden_size/choice": {
"_type": "choice",
"_value": [ 0, 1, 2 ]
"_value": [
124,
512,
1024
]
},
"mnist/self.learning_rate/randint": {
"_type": "randint",
"_value": [ 2, 3, 5 ]
"_value": [
2,
3,
5
]
},
"mnist/tf.nn.relu/function_choice": {
"_type": "choice",
"_value": [ 0, 1, 2 ]
"_value": [
"tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)",
"tf.nn.sigmoid(conv2d(x_image, W_conv1) + b_conv1)",
"tf.nn.tanh(conv2d(x_image, W_conv1) + b_conv1)"
]
},
"mnist/max_pool/function_choice": {
"_type": "choice",
"_value": [ 0, 1 ]
"_value": [
"max_pool(h_conv1, self.pool_size)",
"avg_pool(h_conv1, self.pool_size)"
]
},
"mnist/batch_size/choice": {
"_type": "choice",
"_value": [ 0, 1, 2 ]
"_value": [
50,
250,
500
]
},
"mnist/dropout_rate/choice": {
"_type": "choice",
"_value": [ 0, 1 ]
"_value": [
1,
5
]
},
"dir.simple/conv_size/choice": {
"_type": "choice",
"_value": [ 0, 1, 2, 3 ]
"_value": [
2,
3,
5,
7
]
},
"dir.simple/abc/choice": {
"_type": "choice",
"_value": [
"2",
3,
"(5 * 6)",
"{(1): 2, '3': 4}",
"[1, 2, 3]"
]
},
"dir.simple/max_pool/function_choice": {
"_type": "choice",
"_value": [ 0, 1 ]
"_value": [
"max_pool(h_conv1)",
"avg_pool(h_conv2, h_conv3)"
]
},
"dir.simple/max_poo/function_choice": {
"_type": "choice",
"_value": [
"max_poo(h_conv1)",
"(2 * 3 + 4)",
"(lambda x: 1 + x)"
]
}
}
\ No newline at end of file
......@@ -3,8 +3,12 @@ def max_pool(k):
h_conv1=1
"""@nni.variable(nni.choice(2,3,5,7),name=conv_size)"""
conv_size = 5
"""@nni.function_choice(max_pool(h_conv1),avg_pool(h_conv2,h_conv3),name=max_pool)"""
"""@nni.variable(nni.choice('2',3,5*6,{1:2, '3':4},[1,2,3]),name=abc)"""
abc = 5
"""@nni.function_choice(max_pool(h_conv1), avg_pool(h_conv2,h_conv3), name=max_pool)"""
h_pool1 = max_pool(h_conv1)
"""@nni.function_choice(max_poo(h_conv1), 2 * 3 + 4, lambda x: 1+x, name=max_poo)"""
h_pool2 = max_poo(h_conv1)
test_acc=1
'''@nni.report_intermediate_result(test_acc)'''
test_acc=2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment