Commit f1210a9c authored by Yuge Zhang's avatar Yuge Zhang Committed by QuanluZhang
Browse files

Add support for @nni.training_update in codegen (#1564)

* add support for training_update in codegen and prettify code in nni annotation
parent 4b5b6876
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
import ast import ast
import astor import astor
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
def parse_annotation_mutable_layers(code, lineno, nas_mode): def parse_annotation_mutable_layers(code, lineno, nas_mode):
...@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
fields['optional_inputs'] = True fields['optional_inputs'] = True
elif k.id == 'optional_input_size': elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size' assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list' assert type(value) is ast.Num or type(value) is ast.List, \
'Value of optional_input_size should be a number or list'
optional_input_size = value optional_input_size = value
fields['optional_input_size'] = True fields['optional_input_size'] = True
elif k.id == 'layer_output': elif k.id == 'layer_output':
...@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
nodes.append(node) nodes.append(node)
return nodes return nodes
def parse_annotation(code): def parse_annotation(code):
"""Parse an annotation string. """Parse an annotation string.
Return an AST Expr node. Return an AST Expr node.
...@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer): ...@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node) return self._visit_children(node)
def _visit_string(self, node): def _visit_string(self, node):
string = node.value.s string = node.value.s
if string.startswith('@nni.'): if string.startswith('@nni.'):
...@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer): ...@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function') raise AssertionError('Unexpected annotation function')
def _visit_children(self, node): def _visit_children(self, node):
self.stack.append(None) self.stack.append(None)
self.generic_visit(node) self.generic_visit(node)
......
...@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer):
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n] 'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
} }
def visit_Call(self, node): # pylint: disable=invalid-name def visit_Call(self, node): # pylint: disable=invalid-name
self.generic_visit(node) self.generic_visit(node)
......
...@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning ...@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning
para_cfg = None para_cfg = None
prefix_name = None prefix_name = None
def parse_annotation_mutable_layers(code, lineno): def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation. """Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes Return a list of AST Expr nodes
...@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno): ...@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
nodes.append(node) nodes.append(node)
return nodes return nodes
def parse_annotation(code): def parse_annotation(code):
"""Parse an annotation string. """Parse an annotation string.
Return an AST Expr node. Return an AST Expr node.
...@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer): ...@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node) return self._visit_children(node)
def _visit_string(self, node): def _visit_string(self, node):
string = node.value.s string = node.value.s
if string.startswith('@nni.'): if string.startswith('@nni.'):
...@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer): ...@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer):
return node # not an annotation, ignore it return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'): if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code." deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. " \
"Please remove this line in the trial code."
print_warning(deprecated_message) print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[])) return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'): if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:]) module = ast.parse(string[1:])
arg = module.body[0].value.args[0] arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[])) return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'): if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:]) module = ast.parse(string[1:])
arg = module.body[0].value.args[0] arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[])) return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'): if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno) return parse_annotation_mutable_layers(string[1:], node.lineno)
...@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer): ...@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function') raise AssertionError('Unexpected annotation function')
def _visit_children(self, node): def _visit_children(self, node):
self.stack.append(None) self.stack.append(None)
self.generic_visit(node) self.generic_visit(node)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment