Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
763f2c87
Unverified
Commit
763f2c87
authored
Jan 27, 2022
by
Yuge Zhang
Committed by
GitHub
Jan 27, 2022
Browse files
Fix serializer for complex kinds of arguments (#4487)
parent
bb0a8700
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
102 additions
and
15 deletions
+102
-15
nni/common/serializer.py
nni/common/serializer.py
+56
-13
test/ut/sdk/imported/_test_serializer_py38.py
test/ut/sdk/imported/_test_serializer_py38.py
+10
-0
test/ut/sdk/test_serializer.py
test/ut/sdk/test_serializer.py
+36
-2
No files found.
nni/common/serializer.py
View file @
763f2c87
...
...
@@ -219,6 +219,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Therefore, in some cases, some positional arguments will still be kept.
.. warning::
...
...
@@ -451,27 +452,69 @@ def _formulate_single_argument(arg):
def
_formulate_arguments
(
func
,
args
,
kwargs
,
kw_only
,
is_class_init
=
False
):
# This is to formulate the arguments and make them well-formed.
if
kw_only
:
# Match arguments with given arguments, so that we can use keyword arguments as much as possible.
# Mutators don't like positional arguments. Positional arguments might not supply enough information.
# get arguments passed to a function, and save it as a dict
argname_list
=
list
(
inspect
.
signature
(
func
).
parameters
.
keys
())
insp_parameters
=
inspect
.
signature
(
func
).
parameters
argname_list
=
list
(
insp_parameters
.
keys
())
if
is_class_init
:
argname_list
=
argname_list
[
1
:]
full_args
=
{}
# match arguments with given arguments
# args should be longer than given list, because args can be used in a kwargs way
assert
len
(
args
)
<=
len
(
argname_list
),
f
'Length of
{
args
}
is greater than length of
{
argname_list
}
.'
for
argname
,
value
in
zip
(
argname_list
,
args
):
full_args
[
argname
]
=
value
positional_args
=
[]
keyword_args
=
{}
# According to https://docs.python.org/3/library/inspect.html#inspect.Parameter, there are five kinds of parameters
# in Python. We only try to handle POSITIONAL_ONLY and POSITIONAL_OR_KEYWORD here.
# Example:
# For foo(a, b, *c, **d), a and b and c should be kept.
# For foo(a, b, /, d), a and b should be kept.
for
i
,
value
in
enumerate
(
args
):
if
i
>=
len
(
argname_list
):
raise
ValueError
(
f
'
{
func
}
receives extra argument:
{
value
}
.'
)
argname
=
argname_list
[
i
]
if
insp_parameters
[
argname
].
kind
==
inspect
.
Parameter
.
POSITIONAL_ONLY
:
# positional only. have to be kept.
positional_args
.
append
(
value
)
elif
insp_parameters
[
argname
].
kind
==
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
:
# this should be the most common case
keyword_args
[
argname
]
=
value
elif
insp_parameters
[
argname
].
kind
==
inspect
.
Parameter
.
VAR_POSITIONAL
:
# Any previous preprocessing might be wrong. Clean them all.
# Any parameters that appear before a VAR_POSITIONAL should be kept positional.
# Otherwise, VAR_POSITIONAL might not work.
# For the cases I've tested, any parameters that appear after a VAR_POSITIONAL are considered keyword only.
# But, if args is not long enough for VAR_POSITIONAL to be encountered, they should be handled by other if-branches.
positional_args
=
args
keyword_args
=
{}
break
else
:
# kind has to be one of `KEYWORD_ONLY` and `VAR_KEYWORD`
raise
ValueError
(
f
'
{
func
}
receives positional argument:
{
value
}
, but the parameter type is found to be keyword only.'
)
# use kwargs to override
full
_args
.
update
(
kwargs
)
keyword
_args
.
update
(
kwargs
)
args
,
kwargs
=
[],
full_args
if
positional_args
:
# Raise a warning if some arguments are not convertible to keyword arguments.
warnings
.
warn
(
f
'Found positional arguments
{
positional_args
}
should processing parameters of
{
func
}
. '
'We recommend always using keyword arguments to specify parameters. '
'For example: `nn.LSTM(input_size=2, hidden_size=2)` instead of `nn.LSTM(2, 2)`.'
)
else
:
# keep them unprocessed
positional_args
,
keyword_args
=
args
,
kwargs
args
=
[
_formulate_single_argument
(
arg
)
for
arg
in
args
]
kwargs
=
{
k
:
_formulate_single_argument
(
arg
)
for
k
,
arg
in
kwargs
.
items
()}
# do some extra conversions to the arguments.
positional_args
=
[
_formulate_single_argument
(
arg
)
for
arg
in
positional_args
]
keyword_args
=
{
k
:
_formulate_single_argument
(
arg
)
for
k
,
arg
in
keyword_args
.
items
()}
return
list
(
args
)
,
k
w
args
return
positional_
args
,
k
eyword_
args
def
_is_function
(
obj
:
Any
)
->
bool
:
...
...
test/ut/sdk/imported/_test_serializer_py38.py
0 → 100644
View file @
763f2c87
import
nni
def
test_positional_only
():
def
foo
(
a
,
b
,
/
,
c
):
pass
d
=
nni
.
trace
(
foo
)(
1
,
2
,
c
=
3
)
assert
d
.
trace_args
==
[
1
,
2
]
assert
d
.
trace_kwargs
==
dict
(
c
=
3
)
test/ut/sdk/test_serializer.py
View file @
763f2c87
import
math
import
re
import
sys
from
pathlib
import
Path
...
...
@@ -16,6 +15,10 @@ if True: # prevent auto formatting
sys
.
path
.
insert
(
0
,
Path
(
__file__
).
parent
.
as_posix
())
from
imported.model
import
ImportTest
# this test cannot be directly put in this file. It will cause syntax error for python <= 3.7.
if
tuple
(
sys
.
version_info
)
>=
(
3
,
8
):
from
imported._test_serializer_py38
import
test_positional_only
@
nni
.
trace
class
SimpleClass
:
...
...
@@ -238,6 +241,36 @@ def test_generator():
print
(
optimizer
.
trace_kwargs
)
def
test_arguments_kind
():
def
foo
(
a
,
b
,
*
c
,
**
d
):
pass
d
=
nni
.
trace
(
foo
)(
1
,
2
,
3
,
4
)
assert
d
.
trace_args
==
[
1
,
2
,
3
,
4
]
assert
d
.
trace_kwargs
==
{}
d
=
nni
.
trace
(
foo
)(
a
=
1
,
b
=
2
)
assert
d
.
trace_kwargs
==
dict
(
a
=
1
,
b
=
2
)
d
=
nni
.
trace
(
foo
)(
1
,
b
=
2
)
# this is not perfect, but it's safe
assert
d
.
trace_kwargs
==
dict
(
a
=
1
,
b
=
2
)
def
foo
(
a
,
*
,
b
=
3
,
c
=
5
):
pass
d
=
nni
.
trace
(
foo
)(
1
,
b
=
2
,
c
=
3
)
assert
d
.
trace_kwargs
==
dict
(
a
=
1
,
b
=
2
,
c
=
3
)
import
torch.nn
as
nn
lstm
=
nni
.
trace
(
nn
.
LSTM
)(
2
,
2
)
assert
lstm
.
input_size
==
2
assert
lstm
.
hidden_size
==
2
assert
lstm
.
trace_args
==
[
2
,
2
]
lstm
=
nni
.
trace
(
nn
.
LSTM
)(
input_size
=
2
,
hidden_size
=
2
)
assert
lstm
.
trace_kwargs
==
{
'input_size'
:
2
,
'hidden_size'
:
2
}
if
__name__
==
'__main__'
:
# test_simple_class()
...
...
@@ -245,4 +278,5 @@ if __name__ == '__main__':
# test_nested_class()
# test_unserializable()
# test_basic_unit()
test_generator
()
# test_generator()
test_arguments_kind
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment