Unverified Commit ebca3cec authored by J-shang's avatar J-shang Committed by GitHub
Browse files

support annotation in python 3.8 (#2881)


Co-authored-by: default avatarNing Shang <nishang@microsoft.com>
parent 1d8b8e48
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import ast import ast
import astor import astor
from .utils import ast_Num, ast_Str
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
...@@ -37,13 +38,13 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -37,13 +38,13 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
for call in value.elts: for call in value.elts:
assert type(call) is ast.Call, 'Element in layer_choice should be function call' assert type(call) is ast.Call, 'Element in layer_choice should be function call'
call_name = astor.to_source(call).strip() call_name = astor.to_source(call).strip()
call_funcs_keys.append(ast.Str(s=call_name)) call_funcs_keys.append(ast_Str(s=call_name))
call_funcs_values.append(call.func) call_funcs_values.append(call.func)
assert not call.args, 'Number of args without keyword should be zero' assert not call.args, 'Number of args without keyword should be zero'
kw_args = [] kw_args = []
kw_values = [] kw_values = []
for kw in call.keywords: for kw in call.keywords:
kw_args.append(ast.Str(s=kw.arg)) kw_args.append(ast_Str(s=kw.arg))
kw_values.append(kw.value) kw_values.append(kw.value)
call_kwargs_values.append(ast.Dict(keys=kw_args, values=kw_values)) call_kwargs_values.append(ast.Dict(keys=kw_args, values=kw_values))
call_funcs = ast.Dict(keys=call_funcs_keys, values=call_funcs_values) call_funcs = ast.Dict(keys=call_funcs_keys, values=call_funcs_values)
...@@ -57,12 +58,12 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -57,12 +58,12 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
elif k.id == 'optional_inputs': elif k.id == 'optional_inputs':
assert not fields['optional_inputs'], 'Duplicated field: optional_inputs' assert not fields['optional_inputs'], 'Duplicated field: optional_inputs'
assert type(value) is ast.List, 'Value of optional_inputs should be a list' assert type(value) is ast.List, 'Value of optional_inputs should be a list'
var_names = [ast.Str(s=astor.to_source(var).strip()) for var in value.elts] var_names = [ast_Str(s=astor.to_source(var).strip()) for var in value.elts]
optional_inputs = ast.Dict(keys=var_names, values=value.elts) optional_inputs = ast.Dict(keys=var_names, values=value.elts)
fields['optional_inputs'] = True fields['optional_inputs'] = True
elif k.id == 'optional_input_size': elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size' assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, \ assert type(value) is ast_Num or type(value) is ast.List, \
'Value of optional_input_size should be a number or list' 'Value of optional_input_size should be a number or list'
optional_input_size = value optional_input_size = value
fields['optional_input_size'] = True fields['optional_input_size'] = True
...@@ -79,8 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -79,8 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt) mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt)
mutable_layer_cnt += 1 mutable_layer_cnt += 1
target_call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='mutable_layer', ctx=ast.Load()) target_call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='mutable_layer', ctx=ast.Load())
target_call_args = [ast.Str(s=mutable_id), target_call_args = [ast_Str(s=mutable_id),
ast.Str(s=mutable_layer_id), ast_Str(s=mutable_layer_id),
call_funcs, call_funcs,
call_kwargs] call_kwargs]
if fields['fixed_inputs']: if fields['fixed_inputs']:
...@@ -93,8 +94,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode): ...@@ -93,8 +94,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
target_call_args.append(optional_input_size) target_call_args.append(optional_input_size)
else: else:
target_call_args.append(ast.Dict(keys=[], values=[])) target_call_args.append(ast.Dict(keys=[], values=[]))
target_call_args.append(ast.Num(n=0)) target_call_args.append(ast_Num(n=0))
target_call_args.append(ast.Str(s=nas_mode)) target_call_args.append(ast_Str(s=nas_mode))
if nas_mode in ['enas_mode', 'oneshot_mode', 'darts_mode']: if nas_mode in ['enas_mode', 'oneshot_mode', 'darts_mode']:
target_call_args.append(ast.Name(id='tensorflow')) target_call_args.append(ast.Name(id='tensorflow'))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[]) target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
...@@ -151,7 +152,7 @@ def parse_nni_variable(code): ...@@ -151,7 +152,7 @@ def parse_nni_variable(code):
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function' assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'
name_str = astor.to_source(name).strip() name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str)) keyword_arg = ast.keyword(arg='name', value=ast_Str(s=name_str))
arg.keywords.append(keyword_arg) arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice': if arg.func.attr == 'choice':
convert_args_to_dict(arg) convert_args_to_dict(arg)
...@@ -169,7 +170,7 @@ def parse_nni_function(code): ...@@ -169,7 +170,7 @@ def parse_nni_function(code):
convert_args_to_dict(call, with_lambda=True) convert_args_to_dict(call, with_lambda=True)
name_str = astor.to_source(name).strip() name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str) call.keywords[0].value = ast_Str(s=name_str)
return call, funcs return call, funcs
...@@ -180,12 +181,12 @@ def convert_args_to_dict(call, with_lambda=False): ...@@ -180,12 +181,12 @@ def convert_args_to_dict(call, with_lambda=False):
""" """
keys, values = list(), list() keys, values = list(), list()
for arg in call.args: for arg in call.args:
if type(arg) in [ast.Str, ast.Num]: if type(arg) in [ast_Str, ast_Num]:
arg_value = arg arg_value = arg
else: else:
# if arg is not a string or a number, we use its source code as the key # if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"') arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value)) arg_value = ast_Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value) keys.append(arg_value)
values.append(arg) values.append(arg)
...@@ -209,7 +210,7 @@ def test_variable_equal(node1, node2): ...@@ -209,7 +210,7 @@ def test_variable_equal(node1, node2):
return False return False
if isinstance(node1, ast.AST): if isinstance(node1, ast.AST):
for k, v in vars(node1).items(): for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'): if k in ('lineno', 'col_offset', 'ctx', 'end_lineno', 'end_col_offset'):
continue continue
if not test_variable_equal(v, getattr(node2, k)): if not test_variable_equal(v, getattr(node2, k)):
return False return False
...@@ -282,7 +283,7 @@ class Transformer(ast.NodeTransformer): ...@@ -282,7 +283,7 @@ class Transformer(ast.NodeTransformer):
annotation = self.stack[-1] annotation = self.stack[-1]
# this is a standalone string, may be an annotation # this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str: if type(node) is ast.Expr and type(node.value) is ast_Str:
# must not annotate an annotation string # must not annotate an annotation string
assert annotation is None, 'Annotating an annotation' assert annotation is None, 'Annotating an annotation'
return self._visit_string(node) return self._visit_string(node)
...@@ -306,7 +307,7 @@ class Transformer(ast.NodeTransformer): ...@@ -306,7 +307,7 @@ class Transformer(ast.NodeTransformer):
if string.startswith('@nni.training_update'): if string.startswith('@nni.training_update'):
expr = parse_annotation(string[1:]) expr = parse_annotation(string[1:])
call_node = expr.value call_node = expr.value
call_node.args.insert(0, ast.Str(s=self.nas_mode)) call_node.args.insert(0, ast_Str(s=self.nas_mode))
return expr return expr
if string.startswith('@nni.report_intermediate_result') \ if string.startswith('@nni.report_intermediate_result') \
......
...@@ -6,6 +6,8 @@ import numbers ...@@ -6,6 +6,8 @@ import numbers
import astor import astor
from .utils import ast_Num, ast_Str
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
...@@ -44,7 +46,7 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -44,7 +46,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
self.search_space[key]['_value'][mutable_layer] = { self.search_space[key]['_value'][mutable_layer] = {
'layer_choice': [k.s for k in args[2].keys], 'layer_choice': [k.s for k in args[2].keys],
'optional_inputs': [k.s for k in args[5].keys], 'optional_inputs': [k.s for k in args[5].keys],
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n] 'optional_input_size': args[6].n if isinstance(args[6], ast_Num) else [args[6].elts[0].n, args[6].elts[1].n]
} }
def visit_Call(self, node): # pylint: disable=invalid-name def visit_Call(self, node): # pylint: disable=invalid-name
...@@ -73,7 +75,7 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -73,7 +75,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
# there is a `name` argument # there is a `name` argument
assert len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"' assert len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"'
assert node.keywords[0].arg == 'name', 'Smart paramater\'s keyword argument is not "name"' assert node.keywords[0].arg == 'name', 'Smart paramater\'s keyword argument is not "name"'
assert type(node.keywords[0].value) is ast.Str, 'Smart parameter\'s name must be string literal' assert type(node.keywords[0].value) is ast_Str, 'Smart parameter\'s name must be string literal'
name = node.keywords[0].value.s name = node.keywords[0].value.s
specified_name = True specified_name = True
else: else:
...@@ -86,7 +88,7 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -86,7 +88,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user # we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
assert len(node.args) == 1, 'Smart parameter has arguments other than dict' assert len(node.args) == 1, 'Smart parameter has arguments other than dict'
# check if it is a number or a string and get its value accordingly # check if it is a number or a string and get its value accordingly
args = [key.n if type(key) is ast.Num else key.s for key in node.args[0].keys] args = [key.n if type(key) is ast_Num else key.s for key in node.args[0].keys]
else: else:
# arguments of other functions must be literal number # arguments of other functions must be literal number
assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \ assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \
...@@ -95,7 +97,7 @@ class SearchSpaceGenerator(ast.NodeTransformer): ...@@ -95,7 +97,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
key = self.module_name + '/' + name + '/' + func key = self.module_name + '/' + name + '/' + func
# store key in ast.Call # store key in ast.Call
node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key))) node.keywords.append(ast.keyword(arg='key', value=ast_Str(s=key)))
if func == 'function_choice': if func == 'function_choice':
func = 'choice' func = 'choice'
......
...@@ -5,6 +5,8 @@ import ast ...@@ -5,6 +5,8 @@ import ast
import astor import astor
from nni_cmd.common_utils import print_warning from nni_cmd.common_utils import print_warning
from .utils import ast_Num, ast_Str
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
para_cfg = None para_cfg = None
...@@ -134,7 +136,7 @@ def parse_nni_variable(code): ...@@ -134,7 +136,7 @@ def parse_nni_variable(code):
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function' assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'
name_str = astor.to_source(name).strip() name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str)) keyword_arg = ast.keyword(arg='name', value=ast_Str(s=name_str))
arg.keywords.append(keyword_arg) arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice': if arg.func.attr == 'choice':
convert_args_to_dict(arg) convert_args_to_dict(arg)
...@@ -152,7 +154,7 @@ def parse_nni_function(code): ...@@ -152,7 +154,7 @@ def parse_nni_function(code):
convert_args_to_dict(call, with_lambda=True) convert_args_to_dict(call, with_lambda=True)
name_str = astor.to_source(name).strip() name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str) call.keywords[0].value = ast_Str(s=name_str)
return call, funcs return call, funcs
...@@ -163,12 +165,12 @@ def convert_args_to_dict(call, with_lambda=False): ...@@ -163,12 +165,12 @@ def convert_args_to_dict(call, with_lambda=False):
""" """
keys, values = list(), list() keys, values = list(), list()
for arg in call.args: for arg in call.args:
if type(arg) in [ast.Str, ast.Num]: if type(arg) in [ast_Str, ast_Num]:
arg_value = arg arg_value = arg
else: else:
# if arg is not a string or a number, we use its source code as the key # if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"') arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value)) arg_value = ast_Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value) keys.append(arg_value)
values.append(arg) values.append(arg)
...@@ -192,7 +194,7 @@ def test_variable_equal(node1, node2): ...@@ -192,7 +194,7 @@ def test_variable_equal(node1, node2):
return False return False
if isinstance(node1, ast.AST): if isinstance(node1, ast.AST):
for k, v in vars(node1).items(): for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'): if k in ('lineno', 'col_offset', 'ctx', 'end_lineno', 'end_col_offset'):
continue continue
if not test_variable_equal(v, getattr(node2, k)): if not test_variable_equal(v, getattr(node2, k)):
return False return False
...@@ -264,7 +266,7 @@ class Transformer(ast.NodeTransformer): ...@@ -264,7 +266,7 @@ class Transformer(ast.NodeTransformer):
annotation = self.stack[-1] annotation = self.stack[-1]
# this is a standalone string, may be an annotation # this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str: if type(node) is ast.Expr and type(node.value) is ast_Str:
# must not annotate an annotation string # must not annotate an annotation string
assert annotation is None, 'Annotating an annotation' assert annotation is None, 'Annotating an annotation'
return self._visit_string(node) return self._visit_string(node)
...@@ -290,23 +292,23 @@ class Transformer(ast.NodeTransformer): ...@@ -290,23 +292,23 @@ class Transformer(ast.NodeTransformer):
"Please remove this line in the trial code." "Please remove this line in the trial code."
print_warning(deprecated_message) print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[])) args=[ast_Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.training_update'): if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[])) args=[ast_Str(s='Training update here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'): if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:]) module = ast.parse(string[1:])
arg = module.body[0].value.args[0] arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[])) args=[ast_Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'): if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:]) module = ast.parse(string[1:])
arg = module.body[0].value.args[0] arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[])) args=[ast_Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'): if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno) return parse_annotation_mutable_layers(string[1:], node.lineno)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import ast
from sys import version_info
if version_info >= (3, 8):
ast_Num = ast_Str = ast_Bytes = ast_NameConstant = ast_Ellipsis = ast.Constant
else:
ast_Num = ast.Num
ast_Str = ast.Str
ast_Bytes = ast.Bytes
ast_NameConstant = ast.NameConstant
ast_Ellipsis = ast.Ellipsis
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment