Commit f1210a9c authored by Yuge Zhang's avatar Yuge Zhang Committed by QuanluZhang
Browse files

Add support for @nni.training_update in codegen (#1564)

* add support for training_update in codegen and prettify code in nni annotation
parent 4b5b6876
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
import ast import ast
import astor import astor
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
def parse_annotation_mutable_layers(code, lineno, nas_mode): def parse_annotation_mutable_layers(code, lineno, nas_mode):
...@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
fields['optional_inputs'] = True fields['optional_inputs'] = True
elif k.id == 'optional_input_size': elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size' assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list' assert type(value) is ast.Num or type(value) is ast.List, \
'Value of optional_input_size should be a number or list'
optional_input_size = value optional_input_size = value
fields['optional_input_size'] = True fields['optional_input_size'] = True
elif k.id == 'layer_output': elif k.id == 'layer_output':
...@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
nodes.append(node) nodes.append(node)
return nodes return nodes
def parse_annotation(code): def parse_annotation(code):
"""Parse an annotation string. """Parse an annotation string.
Return an AST Expr node. Return an AST Expr node.
...@@ -198,7 +201,7 @@ def convert_args_to_dict(call, with_lambda=False): ...@@ -198,7 +201,7 @@ def convert_args_to_dict(call, with_lambda=False):
if type(arg) in [ast.Str, ast.Num]: if type(arg) in [ast.Str, ast.Num]:
arg_value = arg arg_value = arg
else: else:
# if arg is not a string or a number, we use its source code as the key # if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"') arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value)) arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg arg = make_lambda(arg) if with_lambda else arg
...@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer): ...@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node) return self._visit_children(node)
def _visit_string(self, node): def _visit_string(self, node):
string = node.value.s string = node.value.s
if string.startswith('@nni.'): if string.startswith('@nni.'):
...@@ -325,7 +327,7 @@ class Transformer(ast.NodeTransformer): ...@@ -325,7 +327,7 @@ class Transformer(ast.NodeTransformer):
call_node.args.insert(0, ast.Str(s=self.nas_mode)) call_node.args.insert(0, ast.Str(s=self.nas_mode))
return expr return expr
if string.startswith('@nni.report_intermediate_result') \ if string.startswith('@nni.report_intermediate_result') \
or string.startswith('@nni.report_final_result') \ or string.startswith('@nni.report_final_result') \
or string.startswith('@nni.get_next_parameter'): or string.startswith('@nni.get_next_parameter'):
return parse_annotation(string[1:]) # expand annotation string to code return parse_annotation(string[1:]) # expand annotation string to code
...@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer): ...@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function') raise AssertionError('Unexpected annotation function')
def _visit_children(self, node): def _visit_children(self, node):
self.stack.append(None) self.stack.append(None)
self.generic_visit(node) self.generic_visit(node)
......
...@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer):
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n] 'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
} }
def visit_Call(self, node): # pylint: disable=invalid-name def visit_Call(self, node): # pylint: disable=invalid-name
self.generic_visit(node) self.generic_visit(node)
...@@ -108,7 +107,7 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -108,7 +107,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
else: else:
# arguments of other functions must be literal number # arguments of other functions must be literal number
assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \ assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \
'Smart parameter\'s arguments must be number literals' 'Smart parameter\'s arguments must be number literals'
args = [ast.literal_eval(astor.to_source(arg)) for arg in node.args] args = [ast.literal_eval(astor.to_source(arg)) for arg in node.args]
key = self.module_name + '/' + name + '/' + func key = self.module_name + '/' + name + '/' + func
......
...@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning ...@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning
para_cfg = None para_cfg = None
prefix_name = None prefix_name = None
def parse_annotation_mutable_layers(code, lineno): def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation. """Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes Return a list of AST Expr nodes
...@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno): ...@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
nodes.append(node) nodes.append(node)
return nodes return nodes
def parse_annotation(code): def parse_annotation(code):
"""Parse an annotation string. """Parse an annotation string.
Return an AST Expr node. Return an AST Expr node.
...@@ -182,7 +184,7 @@ def convert_args_to_dict(call, with_lambda=False): ...@@ -182,7 +184,7 @@ def convert_args_to_dict(call, with_lambda=False):
if type(arg) in [ast.Str, ast.Num]: if type(arg) in [ast.Str, ast.Num]:
arg_value = arg arg_value = arg
else: else:
# if arg is not a string or a number, we use its source code as the key # if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"') arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value)) arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg arg = make_lambda(arg) if with_lambda else arg
...@@ -217,7 +219,7 @@ def test_variable_equal(node1, node2): ...@@ -217,7 +219,7 @@ def test_variable_equal(node1, node2):
if len(node1) != len(node2): if len(node1) != len(node2):
return False return False
return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2)) return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2))
return node1 == node2 return node1 == node2
...@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer): ...@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node) return self._visit_children(node)
def _visit_string(self, node): def _visit_string(self, node):
string = node.value.s string = node.value.s
if string.startswith('@nni.'): if string.startswith('@nni.'):
...@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer): ...@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer):
return node # not an annotation, ignore it return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'): if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code." deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. " \
"Please remove this line in the trial code."
print_warning(deprecated_message) print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[])) return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'): if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:]) module = ast.parse(string[1:])
arg = module.body[0].value.args[0] arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[])) return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'): if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:]) module = ast.parse(string[1:])
arg = module.body[0].value.args[0] arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[])) return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'): if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno) return parse_annotation_mutable_layers(string[1:], node.lineno)
...@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer): ...@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function') raise AssertionError('Unexpected annotation function')
def _visit_children(self, node): def _visit_children(self, node):
self.stack.append(None) self.stack.append(None)
self.generic_visit(node) self.generic_visit(node)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment