Commit f1210a9c authored by Yuge Zhang's avatar Yuge Zhang Committed by QuanluZhang
Browse files

Add support for @nni.training_update in codegen (#1564)

* add support for training_update in codegen and prettify code in nni annotation
parent 4b5b6876
......@@ -22,6 +22,7 @@
import ast
import astor
# pylint: disable=unidiomatic-typecheck
def parse_annotation_mutable_layers(code, lineno, nas_mode):
......@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list'
assert type(value) is ast.Num or type(value) is ast.List, \
'Value of optional_input_size should be a number or list'
optional_input_size = value
fields['optional_input_size'] = True
elif k.id == 'layer_output':
......@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
nodes.append(node)
return nodes
def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
......@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node)
def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
......@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function')
def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
......
......@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer):
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
}
def visit_Call(self, node): # pylint: disable=invalid-name
self.generic_visit(node)
......
......@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning
para_cfg = None
prefix_name = None
def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
......@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
nodes.append(node)
return nodes
def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
......@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node)
def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
......@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer):
return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. " \
"Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
......@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function')
def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment