Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
ebca3cec
Unverified
Commit
ebca3cec
authored
Sep 14, 2020
by
J-shang
Committed by
GitHub
Sep 14, 2020
Browse files
support annotation in python 3.8 (#2881)
Co-authored-by:
Ning Shang
<
nishang@microsoft.com
>
parent
1d8b8e48
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
49 additions
and
29 deletions
+49
-29
tools/nni_annotation/code_generator.py
tools/nni_annotation/code_generator.py
+16
-15
tools/nni_annotation/search_space_generator.py
tools/nni_annotation/search_space_generator.py
+6
-4
tools/nni_annotation/specific_code_generator.py
tools/nni_annotation/specific_code_generator.py
+12
-10
tools/nni_annotation/utils.py
tools/nni_annotation/utils.py
+15
-0
No files found.
tools/nni_annotation/code_generator.py
View file @
ebca3cec
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
ast
import
ast
import
astor
import
astor
from
.utils
import
ast_Num
,
ast_Str
# pylint: disable=unidiomatic-typecheck
# pylint: disable=unidiomatic-typecheck
...
@@ -37,13 +38,13 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
...
@@ -37,13 +38,13 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
for
call
in
value
.
elts
:
for
call
in
value
.
elts
:
assert
type
(
call
)
is
ast
.
Call
,
'Element in layer_choice should be function call'
assert
type
(
call
)
is
ast
.
Call
,
'Element in layer_choice should be function call'
call_name
=
astor
.
to_source
(
call
).
strip
()
call_name
=
astor
.
to_source
(
call
).
strip
()
call_funcs_keys
.
append
(
ast
.
Str
(
s
=
call_name
))
call_funcs_keys
.
append
(
ast
_
Str
(
s
=
call_name
))
call_funcs_values
.
append
(
call
.
func
)
call_funcs_values
.
append
(
call
.
func
)
assert
not
call
.
args
,
'Number of args without keyword should be zero'
assert
not
call
.
args
,
'Number of args without keyword should be zero'
kw_args
=
[]
kw_args
=
[]
kw_values
=
[]
kw_values
=
[]
for
kw
in
call
.
keywords
:
for
kw
in
call
.
keywords
:
kw_args
.
append
(
ast
.
Str
(
s
=
kw
.
arg
))
kw_args
.
append
(
ast
_
Str
(
s
=
kw
.
arg
))
kw_values
.
append
(
kw
.
value
)
kw_values
.
append
(
kw
.
value
)
call_kwargs_values
.
append
(
ast
.
Dict
(
keys
=
kw_args
,
values
=
kw_values
))
call_kwargs_values
.
append
(
ast
.
Dict
(
keys
=
kw_args
,
values
=
kw_values
))
call_funcs
=
ast
.
Dict
(
keys
=
call_funcs_keys
,
values
=
call_funcs_values
)
call_funcs
=
ast
.
Dict
(
keys
=
call_funcs_keys
,
values
=
call_funcs_values
)
...
@@ -57,12 +58,12 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
...
@@ -57,12 +58,12 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
elif
k
.
id
==
'optional_inputs'
:
elif
k
.
id
==
'optional_inputs'
:
assert
not
fields
[
'optional_inputs'
],
'Duplicated field: optional_inputs'
assert
not
fields
[
'optional_inputs'
],
'Duplicated field: optional_inputs'
assert
type
(
value
)
is
ast
.
List
,
'Value of optional_inputs should be a list'
assert
type
(
value
)
is
ast
.
List
,
'Value of optional_inputs should be a list'
var_names
=
[
ast
.
Str
(
s
=
astor
.
to_source
(
var
).
strip
())
for
var
in
value
.
elts
]
var_names
=
[
ast
_
Str
(
s
=
astor
.
to_source
(
var
).
strip
())
for
var
in
value
.
elts
]
optional_inputs
=
ast
.
Dict
(
keys
=
var_names
,
values
=
value
.
elts
)
optional_inputs
=
ast
.
Dict
(
keys
=
var_names
,
values
=
value
.
elts
)
fields
[
'optional_inputs'
]
=
True
fields
[
'optional_inputs'
]
=
True
elif
k
.
id
==
'optional_input_size'
:
elif
k
.
id
==
'optional_input_size'
:
assert
not
fields
[
'optional_input_size'
],
'Duplicated field: optional_input_size'
assert
not
fields
[
'optional_input_size'
],
'Duplicated field: optional_input_size'
assert
type
(
value
)
is
ast
.
Num
or
type
(
value
)
is
ast
.
List
,
\
assert
type
(
value
)
is
ast
_
Num
or
type
(
value
)
is
ast
.
List
,
\
'Value of optional_input_size should be a number or list'
'Value of optional_input_size should be a number or list'
optional_input_size
=
value
optional_input_size
=
value
fields
[
'optional_input_size'
]
=
True
fields
[
'optional_input_size'
]
=
True
...
@@ -79,8 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
...
@@ -79,8 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
mutable_layer_id
=
'mutable_layer_'
+
str
(
mutable_layer_cnt
)
mutable_layer_id
=
'mutable_layer_'
+
str
(
mutable_layer_cnt
)
mutable_layer_cnt
+=
1
mutable_layer_cnt
+=
1
target_call_attr
=
ast
.
Attribute
(
value
=
ast
.
Name
(
id
=
'nni'
,
ctx
=
ast
.
Load
()),
attr
=
'mutable_layer'
,
ctx
=
ast
.
Load
())
target_call_attr
=
ast
.
Attribute
(
value
=
ast
.
Name
(
id
=
'nni'
,
ctx
=
ast
.
Load
()),
attr
=
'mutable_layer'
,
ctx
=
ast
.
Load
())
target_call_args
=
[
ast
.
Str
(
s
=
mutable_id
),
target_call_args
=
[
ast
_
Str
(
s
=
mutable_id
),
ast
.
Str
(
s
=
mutable_layer_id
),
ast
_
Str
(
s
=
mutable_layer_id
),
call_funcs
,
call_funcs
,
call_kwargs
]
call_kwargs
]
if
fields
[
'fixed_inputs'
]:
if
fields
[
'fixed_inputs'
]:
...
@@ -93,8 +94,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
...
@@ -93,8 +94,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
target_call_args
.
append
(
optional_input_size
)
target_call_args
.
append
(
optional_input_size
)
else
:
else
:
target_call_args
.
append
(
ast
.
Dict
(
keys
=
[],
values
=
[]))
target_call_args
.
append
(
ast
.
Dict
(
keys
=
[],
values
=
[]))
target_call_args
.
append
(
ast
.
Num
(
n
=
0
))
target_call_args
.
append
(
ast
_
Num
(
n
=
0
))
target_call_args
.
append
(
ast
.
Str
(
s
=
nas_mode
))
target_call_args
.
append
(
ast
_
Str
(
s
=
nas_mode
))
if
nas_mode
in
[
'enas_mode'
,
'oneshot_mode'
,
'darts_mode'
]:
if
nas_mode
in
[
'enas_mode'
,
'oneshot_mode'
,
'darts_mode'
]:
target_call_args
.
append
(
ast
.
Name
(
id
=
'tensorflow'
))
target_call_args
.
append
(
ast
.
Name
(
id
=
'tensorflow'
))
target_call
=
ast
.
Call
(
func
=
target_call_attr
,
args
=
target_call_args
,
keywords
=
[])
target_call
=
ast
.
Call
(
func
=
target_call_attr
,
args
=
target_call_args
,
keywords
=
[])
...
@@ -151,7 +152,7 @@ def parse_nni_variable(code):
...
@@ -151,7 +152,7 @@ def parse_nni_variable(code):
assert
arg
.
func
.
value
.
id
==
'nni'
,
'nni.variable value is not a NNI function'
assert
arg
.
func
.
value
.
id
==
'nni'
,
'nni.variable value is not a NNI function'
name_str
=
astor
.
to_source
(
name
).
strip
()
name_str
=
astor
.
to_source
(
name
).
strip
()
keyword_arg
=
ast
.
keyword
(
arg
=
'name'
,
value
=
ast
.
Str
(
s
=
name_str
))
keyword_arg
=
ast
.
keyword
(
arg
=
'name'
,
value
=
ast
_
Str
(
s
=
name_str
))
arg
.
keywords
.
append
(
keyword_arg
)
arg
.
keywords
.
append
(
keyword_arg
)
if
arg
.
func
.
attr
==
'choice'
:
if
arg
.
func
.
attr
==
'choice'
:
convert_args_to_dict
(
arg
)
convert_args_to_dict
(
arg
)
...
@@ -169,7 +170,7 @@ def parse_nni_function(code):
...
@@ -169,7 +170,7 @@ def parse_nni_function(code):
convert_args_to_dict
(
call
,
with_lambda
=
True
)
convert_args_to_dict
(
call
,
with_lambda
=
True
)
name_str
=
astor
.
to_source
(
name
).
strip
()
name_str
=
astor
.
to_source
(
name
).
strip
()
call
.
keywords
[
0
].
value
=
ast
.
Str
(
s
=
name_str
)
call
.
keywords
[
0
].
value
=
ast
_
Str
(
s
=
name_str
)
return
call
,
funcs
return
call
,
funcs
...
@@ -180,12 +181,12 @@ def convert_args_to_dict(call, with_lambda=False):
...
@@ -180,12 +181,12 @@ def convert_args_to_dict(call, with_lambda=False):
"""
"""
keys
,
values
=
list
(),
list
()
keys
,
values
=
list
(),
list
()
for
arg
in
call
.
args
:
for
arg
in
call
.
args
:
if
type
(
arg
)
in
[
ast
.
Str
,
ast
.
Num
]:
if
type
(
arg
)
in
[
ast
_
Str
,
ast
_
Num
]:
arg_value
=
arg
arg_value
=
arg
else
:
else
:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
ast
.
Str
(
str
(
arg_value
))
arg_value
=
ast
_
Str
(
str
(
arg_value
))
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
keys
.
append
(
arg_value
)
keys
.
append
(
arg_value
)
values
.
append
(
arg
)
values
.
append
(
arg
)
...
@@ -209,7 +210,7 @@ def test_variable_equal(node1, node2):
...
@@ -209,7 +210,7 @@ def test_variable_equal(node1, node2):
return
False
return
False
if
isinstance
(
node1
,
ast
.
AST
):
if
isinstance
(
node1
,
ast
.
AST
):
for
k
,
v
in
vars
(
node1
).
items
():
for
k
,
v
in
vars
(
node1
).
items
():
if
k
in
(
'lineno'
,
'col_offset'
,
'ctx'
):
if
k
in
(
'lineno'
,
'col_offset'
,
'ctx'
,
'end_lineno'
,
'end_col_offset'
):
continue
continue
if
not
test_variable_equal
(
v
,
getattr
(
node2
,
k
)):
if
not
test_variable_equal
(
v
,
getattr
(
node2
,
k
)):
return
False
return
False
...
@@ -282,7 +283,7 @@ class Transformer(ast.NodeTransformer):
...
@@ -282,7 +283,7 @@ class Transformer(ast.NodeTransformer):
annotation
=
self
.
stack
[
-
1
]
annotation
=
self
.
stack
[
-
1
]
# this is a standalone string, may be an annotation
# this is a standalone string, may be an annotation
if
type
(
node
)
is
ast
.
Expr
and
type
(
node
.
value
)
is
ast
.
Str
:
if
type
(
node
)
is
ast
.
Expr
and
type
(
node
.
value
)
is
ast
_
Str
:
# must not annotate an annotation string
# must not annotate an annotation string
assert
annotation
is
None
,
'Annotating an annotation'
assert
annotation
is
None
,
'Annotating an annotation'
return
self
.
_visit_string
(
node
)
return
self
.
_visit_string
(
node
)
...
@@ -306,7 +307,7 @@ class Transformer(ast.NodeTransformer):
...
@@ -306,7 +307,7 @@ class Transformer(ast.NodeTransformer):
if
string
.
startswith
(
'@nni.training_update'
):
if
string
.
startswith
(
'@nni.training_update'
):
expr
=
parse_annotation
(
string
[
1
:])
expr
=
parse_annotation
(
string
[
1
:])
call_node
=
expr
.
value
call_node
=
expr
.
value
call_node
.
args
.
insert
(
0
,
ast
.
Str
(
s
=
self
.
nas_mode
))
call_node
.
args
.
insert
(
0
,
ast
_
Str
(
s
=
self
.
nas_mode
))
return
expr
return
expr
if
string
.
startswith
(
'@nni.report_intermediate_result'
)
\
if
string
.
startswith
(
'@nni.report_intermediate_result'
)
\
...
...
tools/nni_annotation/search_space_generator.py
View file @
ebca3cec
...
@@ -6,6 +6,8 @@ import numbers
...
@@ -6,6 +6,8 @@ import numbers
import
astor
import
astor
from
.utils
import
ast_Num
,
ast_Str
# pylint: disable=unidiomatic-typecheck
# pylint: disable=unidiomatic-typecheck
...
@@ -44,7 +46,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
...
@@ -44,7 +46,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
self
.
search_space
[
key
][
'_value'
][
mutable_layer
]
=
{
self
.
search_space
[
key
][
'_value'
][
mutable_layer
]
=
{
'layer_choice'
:
[
k
.
s
for
k
in
args
[
2
].
keys
],
'layer_choice'
:
[
k
.
s
for
k
in
args
[
2
].
keys
],
'optional_inputs'
:
[
k
.
s
for
k
in
args
[
5
].
keys
],
'optional_inputs'
:
[
k
.
s
for
k
in
args
[
5
].
keys
],
'optional_input_size'
:
args
[
6
].
n
if
isinstance
(
args
[
6
],
ast
.
Num
)
else
[
args
[
6
].
elts
[
0
].
n
,
args
[
6
].
elts
[
1
].
n
]
'optional_input_size'
:
args
[
6
].
n
if
isinstance
(
args
[
6
],
ast
_
Num
)
else
[
args
[
6
].
elts
[
0
].
n
,
args
[
6
].
elts
[
1
].
n
]
}
}
def
visit_Call
(
self
,
node
):
# pylint: disable=invalid-name
def
visit_Call
(
self
,
node
):
# pylint: disable=invalid-name
...
@@ -73,7 +75,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
...
@@ -73,7 +75,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
# there is a `name` argument
# there is a `name` argument
assert
len
(
node
.
keywords
)
==
1
,
'Smart parameter has keyword argument other than "name"'
assert
len
(
node
.
keywords
)
==
1
,
'Smart parameter has keyword argument other than "name"'
assert
node
.
keywords
[
0
].
arg
==
'name'
,
'Smart paramater
\'
s keyword argument is not "name"'
assert
node
.
keywords
[
0
].
arg
==
'name'
,
'Smart paramater
\'
s keyword argument is not "name"'
assert
type
(
node
.
keywords
[
0
].
value
)
is
ast
.
Str
,
'Smart parameter
\'
s name must be string literal'
assert
type
(
node
.
keywords
[
0
].
value
)
is
ast
_
Str
,
'Smart parameter
\'
s name must be string literal'
name
=
node
.
keywords
[
0
].
value
.
s
name
=
node
.
keywords
[
0
].
value
.
s
specified_name
=
True
specified_name
=
True
else
:
else
:
...
@@ -86,7 +88,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
...
@@ -86,7 +88,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
assert
len
(
node
.
args
)
==
1
,
'Smart parameter has arguments other than dict'
assert
len
(
node
.
args
)
==
1
,
'Smart parameter has arguments other than dict'
# check if it is a number or a string and get its value accordingly
# check if it is a number or a string and get its value accordingly
args
=
[
key
.
n
if
type
(
key
)
is
ast
.
Num
else
key
.
s
for
key
in
node
.
args
[
0
].
keys
]
args
=
[
key
.
n
if
type
(
key
)
is
ast
_
Num
else
key
.
s
for
key
in
node
.
args
[
0
].
keys
]
else
:
else
:
# arguments of other functions must be literal number
# arguments of other functions must be literal number
assert
all
(
isinstance
(
ast
.
literal_eval
(
astor
.
to_source
(
arg
)),
numbers
.
Real
)
for
arg
in
node
.
args
),
\
assert
all
(
isinstance
(
ast
.
literal_eval
(
astor
.
to_source
(
arg
)),
numbers
.
Real
)
for
arg
in
node
.
args
),
\
...
@@ -95,7 +97,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
...
@@ -95,7 +97,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
# store key in ast.Call
# store key in ast.Call
node
.
keywords
.
append
(
ast
.
keyword
(
arg
=
'key'
,
value
=
ast
.
Str
(
s
=
key
)))
node
.
keywords
.
append
(
ast
.
keyword
(
arg
=
'key'
,
value
=
ast
_
Str
(
s
=
key
)))
if
func
==
'function_choice'
:
if
func
==
'function_choice'
:
func
=
'choice'
func
=
'choice'
...
...
tools/nni_annotation/specific_code_generator.py
View file @
ebca3cec
...
@@ -5,6 +5,8 @@ import ast
...
@@ -5,6 +5,8 @@ import ast
import
astor
import
astor
from
nni_cmd.common_utils
import
print_warning
from
nni_cmd.common_utils
import
print_warning
from
.utils
import
ast_Num
,
ast_Str
# pylint: disable=unidiomatic-typecheck
# pylint: disable=unidiomatic-typecheck
para_cfg
=
None
para_cfg
=
None
...
@@ -134,7 +136,7 @@ def parse_nni_variable(code):
...
@@ -134,7 +136,7 @@ def parse_nni_variable(code):
assert
arg
.
func
.
value
.
id
==
'nni'
,
'nni.variable value is not a NNI function'
assert
arg
.
func
.
value
.
id
==
'nni'
,
'nni.variable value is not a NNI function'
name_str
=
astor
.
to_source
(
name
).
strip
()
name_str
=
astor
.
to_source
(
name
).
strip
()
keyword_arg
=
ast
.
keyword
(
arg
=
'name'
,
value
=
ast
.
Str
(
s
=
name_str
))
keyword_arg
=
ast
.
keyword
(
arg
=
'name'
,
value
=
ast
_
Str
(
s
=
name_str
))
arg
.
keywords
.
append
(
keyword_arg
)
arg
.
keywords
.
append
(
keyword_arg
)
if
arg
.
func
.
attr
==
'choice'
:
if
arg
.
func
.
attr
==
'choice'
:
convert_args_to_dict
(
arg
)
convert_args_to_dict
(
arg
)
...
@@ -152,7 +154,7 @@ def parse_nni_function(code):
...
@@ -152,7 +154,7 @@ def parse_nni_function(code):
convert_args_to_dict
(
call
,
with_lambda
=
True
)
convert_args_to_dict
(
call
,
with_lambda
=
True
)
name_str
=
astor
.
to_source
(
name
).
strip
()
name_str
=
astor
.
to_source
(
name
).
strip
()
call
.
keywords
[
0
].
value
=
ast
.
Str
(
s
=
name_str
)
call
.
keywords
[
0
].
value
=
ast
_
Str
(
s
=
name_str
)
return
call
,
funcs
return
call
,
funcs
...
@@ -163,12 +165,12 @@ def convert_args_to_dict(call, with_lambda=False):
...
@@ -163,12 +165,12 @@ def convert_args_to_dict(call, with_lambda=False):
"""
"""
keys
,
values
=
list
(),
list
()
keys
,
values
=
list
(),
list
()
for
arg
in
call
.
args
:
for
arg
in
call
.
args
:
if
type
(
arg
)
in
[
ast
.
Str
,
ast
.
Num
]:
if
type
(
arg
)
in
[
ast
_
Str
,
ast
_
Num
]:
arg_value
=
arg
arg_value
=
arg
else
:
else
:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
ast
.
Str
(
str
(
arg_value
))
arg_value
=
ast
_
Str
(
str
(
arg_value
))
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
keys
.
append
(
arg_value
)
keys
.
append
(
arg_value
)
values
.
append
(
arg
)
values
.
append
(
arg
)
...
@@ -192,7 +194,7 @@ def test_variable_equal(node1, node2):
...
@@ -192,7 +194,7 @@ def test_variable_equal(node1, node2):
return
False
return
False
if
isinstance
(
node1
,
ast
.
AST
):
if
isinstance
(
node1
,
ast
.
AST
):
for
k
,
v
in
vars
(
node1
).
items
():
for
k
,
v
in
vars
(
node1
).
items
():
if
k
in
(
'lineno'
,
'col_offset'
,
'ctx'
):
if
k
in
(
'lineno'
,
'col_offset'
,
'ctx'
,
'end_lineno'
,
'end_col_offset'
):
continue
continue
if
not
test_variable_equal
(
v
,
getattr
(
node2
,
k
)):
if
not
test_variable_equal
(
v
,
getattr
(
node2
,
k
)):
return
False
return
False
...
@@ -264,7 +266,7 @@ class Transformer(ast.NodeTransformer):
...
@@ -264,7 +266,7 @@ class Transformer(ast.NodeTransformer):
annotation
=
self
.
stack
[
-
1
]
annotation
=
self
.
stack
[
-
1
]
# this is a standalone string, may be an annotation
# this is a standalone string, may be an annotation
if
type
(
node
)
is
ast
.
Expr
and
type
(
node
.
value
)
is
ast
.
Str
:
if
type
(
node
)
is
ast
.
Expr
and
type
(
node
.
value
)
is
ast
_
Str
:
# must not annotate an annotation string
# must not annotate an annotation string
assert
annotation
is
None
,
'Annotating an annotation'
assert
annotation
is
None
,
'Annotating an annotation'
return
self
.
_visit_string
(
node
)
return
self
.
_visit_string
(
node
)
...
@@ -290,23 +292,23 @@ class Transformer(ast.NodeTransformer):
...
@@ -290,23 +292,23 @@ class Transformer(ast.NodeTransformer):
"Please remove this line in the trial code."
"Please remove this line in the trial code."
print_warning
(
deprecated_message
)
print_warning
(
deprecated_message
)
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'Get next parameter here...'
)],
keywords
=
[]))
args
=
[
ast
_
Str
(
s
=
'Get next parameter here...'
)],
keywords
=
[]))
if
string
.
startswith
(
'@nni.training_update'
):
if
string
.
startswith
(
'@nni.training_update'
):
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'Training update here...'
)],
keywords
=
[]))
args
=
[
ast
_
Str
(
s
=
'Training update here...'
)],
keywords
=
[]))
if
string
.
startswith
(
'@nni.report_intermediate_result'
):
if
string
.
startswith
(
'@nni.report_intermediate_result'
):
module
=
ast
.
parse
(
string
[
1
:])
module
=
ast
.
parse
(
string
[
1
:])
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'nni.report_intermediate_result: '
),
arg
],
keywords
=
[]))
args
=
[
ast
_
Str
(
s
=
'nni.report_intermediate_result: '
),
arg
],
keywords
=
[]))
if
string
.
startswith
(
'@nni.report_final_result'
):
if
string
.
startswith
(
'@nni.report_final_result'
):
module
=
ast
.
parse
(
string
[
1
:])
module
=
ast
.
parse
(
string
[
1
:])
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'nni.report_final_result: '
),
arg
],
keywords
=
[]))
args
=
[
ast
_
Str
(
s
=
'nni.report_final_result: '
),
arg
],
keywords
=
[]))
if
string
.
startswith
(
'@nni.mutable_layers'
):
if
string
.
startswith
(
'@nni.mutable_layers'
):
return
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
)
return
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
)
...
...
tools/nni_annotation/utils.py
0 → 100644
View file @
ebca3cec
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
ast
from
sys
import
version_info
if
version_info
>=
(
3
,
8
):
ast_Num
=
ast_Str
=
ast_Bytes
=
ast_NameConstant
=
ast_Ellipsis
=
ast
.
Constant
else
:
ast_Num
=
ast
.
Num
ast_Str
=
ast
.
Str
ast_Bytes
=
ast
.
Bytes
ast_NameConstant
=
ast
.
NameConstant
ast_Ellipsis
=
ast
.
Ellipsis
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment