Unverified Commit 763f2c87 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix serializer for complex kinds of arguments (#4487)

parent bb0a8700
...@@ -219,6 +219,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable] ...@@ -219,6 +219,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
list and types. This can be useful to extract semantics, but can be tricky in some corner cases. list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Therefore, in some cases, some positional arguments will still be kept.
.. warning:: .. warning::
...@@ -451,27 +452,69 @@ def _formulate_single_argument(arg): ...@@ -451,27 +452,69 @@ def _formulate_single_argument(arg):
def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
    """Formulate the arguments of a call to ``func`` and make them well-formed.

    When ``kw_only`` is true, positional values in ``args`` are matched against the
    signature of ``func`` and converted to keyword arguments wherever possible.
    Values bound to positional-only parameters stay positional. As soon as a
    ``*args``-style parameter is reached, every value in ``args`` is kept positional
    and nothing is converted. Entries of ``kwargs`` are applied last and override
    any keyword produced by the conversion.

    When ``kw_only`` is false, ``args`` and ``kwargs`` are passed through unchanged.

    In both cases every value is finally passed through ``_formulate_single_argument``.

    Parameters
    ----------
    func
        Callable whose signature is inspected.
    args
        Positional arguments given by the caller.
    kwargs
        Keyword arguments given by the caller.
    kw_only
        Whether to convert positional arguments into keyword arguments.
    is_class_init
        If true, the first parameter in the signature is skipped when matching.

    Returns
    -------
    tuple
        ``(positional_args, keyword_args)``, a list and a dict.

    Raises
    ------
    ValueError
        If ``kw_only`` is true and ``args`` has more values than the signature has
        parameters, or a positional value lines up with a keyword-only or
        ``**kwargs`` parameter.
    """
    if kw_only:
        # Match arguments with given arguments, so that we can use keyword arguments as much as possible.
        # Mutators don't like positional arguments. Positional arguments might not supply enough information.

        # Get the parameters declared by the function, in declaration order.
        insp_parameters = inspect.signature(func).parameters
        argname_list = list(insp_parameters.keys())
        if is_class_init:
            # Drop the first parameter; it is not supplied through ``args``.
            argname_list = argname_list[1:]
        positional_args = []
        keyword_args = {}

        # According to https://docs.python.org/3/library/inspect.html#inspect.Parameter, there are five kinds of
        # parameters in Python. We only try to handle POSITIONAL_ONLY and POSITIONAL_OR_KEYWORD here.
        # Example:
        # For foo(a, b, *c, **d), a and b and c should be kept.
        # For foo(a, b, /, d), a and b should be kept.
        for i, value in enumerate(args):
            if i >= len(argname_list):
                raise ValueError(f'{func} receives extra argument: {value}.')

            argname = argname_list[i]
            kind = insp_parameters[argname].kind
            if kind == inspect.Parameter.POSITIONAL_ONLY:
                # Positional only. Has to be kept positional.
                positional_args.append(value)
            elif kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                # This should be the most common case.
                keyword_args[argname] = value
            elif kind == inspect.Parameter.VAR_POSITIONAL:
                # Any previous preprocessing might be wrong. Clean them all.
                # Any parameters that appear before a VAR_POSITIONAL should be kept positional.
                # Otherwise, VAR_POSITIONAL might not work.
                # In the cases tested, parameters that appear after a VAR_POSITIONAL are keyword only.
                # If args is not long enough to reach the VAR_POSITIONAL, the other branches handle it.
                positional_args = args
                keyword_args = {}
                break
            else:
                # kind has to be one of `KEYWORD_ONLY` and `VAR_KEYWORD`
                raise ValueError(f'{func} receives positional argument: {value}, but the parameter type is found to be keyword only.')

        # use kwargs to override
        keyword_args.update(kwargs)

        if positional_args:
            # Raise a warning if some arguments are not convertible to keyword arguments.
            warnings.warn(f'Found positional arguments {positional_args} when processing parameters of {func}. '
                          'We recommend always using keyword arguments to specify parameters. '
                          'For example: `nn.LSTM(input_size=2, hidden_size=2)` instead of `nn.LSTM(2, 2)`.')
    else:
        # keep them unprocessed
        positional_args, keyword_args = args, kwargs

    # do some extra conversions to the arguments.
    positional_args = [_formulate_single_argument(arg) for arg in positional_args]
    keyword_args = {k: _formulate_single_argument(arg) for k, arg in keyword_args.items()}

    return positional_args, keyword_args
def _is_function(obj: Any) -> bool: def _is_function(obj: Any) -> bool:
......
import nni
def test_positional_only():
    """Check how a traced call splits positional-only and keyword arguments."""
    def foo(a, b, /, c):
        pass

    traced = nni.trace(foo)(1, 2, c=3)

    # ``a`` and ``b`` are positional-only, so they are recorded positionally.
    assert traced.trace_args == [1, 2]
    # ``c`` was passed by keyword and is recorded as such.
    assert traced.trace_kwargs == {'c': 3}
import math import math
import re
import sys import sys
from pathlib import Path from pathlib import Path
...@@ -16,6 +15,10 @@ if True: # prevent auto formatting ...@@ -16,6 +15,10 @@ if True: # prevent auto formatting
sys.path.insert(0, Path(__file__).parent.as_posix()) sys.path.insert(0, Path(__file__).parent.as_posix())
from imported.model import ImportTest from imported.model import ImportTest
# this test cannot be directly put in this file. It will cause syntax error for python <= 3.7.
if tuple(sys.version_info) >= (3, 8):
from imported._test_serializer_py38 import test_positional_only
@nni.trace @nni.trace
class SimpleClass: class SimpleClass:
...@@ -238,6 +241,36 @@ def test_generator(): ...@@ -238,6 +241,36 @@ def test_generator():
print(optimizer.trace_kwargs) print(optimizer.trace_kwargs)
def test_arguments_kind():
    """Check the recorded ``trace_args`` / ``trace_kwargs`` for several parameter kinds."""
    # Signature with both *args and **kwargs.
    def foo(a, b, *c, **d):
        pass

    traced = nni.trace(foo)(1, 2, 3, 4)
    assert traced.trace_args == [1, 2, 3, 4]
    assert traced.trace_kwargs == {}

    traced = nni.trace(foo)(a=1, b=2)
    assert traced.trace_kwargs == {'a': 1, 'b': 2}

    traced = nni.trace(foo)(1, b=2)
    # this is not perfect, but it's safe
    assert traced.trace_kwargs == {'a': 1, 'b': 2}

    # Signature with keyword-only parameters after a bare ``*``.
    def foo(a, *, b=3, c=5):
        pass

    traced = nni.trace(foo)(1, b=2, c=3)
    assert traced.trace_kwargs == {'a': 1, 'b': 2, 'c': 3}

    # A real-world callable: torch's LSTM, called positionally and then by keyword.
    import torch.nn as nn

    positional_lstm = nni.trace(nn.LSTM)(2, 2)
    assert positional_lstm.input_size == 2
    assert positional_lstm.hidden_size == 2
    assert positional_lstm.trace_args == [2, 2]

    keyword_lstm = nni.trace(nn.LSTM)(input_size=2, hidden_size=2)
    assert keyword_lstm.trace_kwargs == {'input_size': 2, 'hidden_size': 2}
if __name__ == '__main__': if __name__ == '__main__':
# test_simple_class() # test_simple_class()
...@@ -245,4 +278,5 @@ if __name__ == '__main__': ...@@ -245,4 +278,5 @@ if __name__ == '__main__':
# test_nested_class() # test_nested_class()
# test_unserializable() # test_unserializable()
# test_basic_unit() # test_basic_unit()
test_generator() # test_generator()
test_arguments_kind()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment