Unverified Commit 90f96ef5 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add support for generator in serializer (#4465)

parent 253dbfd8
......@@ -220,6 +220,11 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
.. warning::
Generators will be first expanded into a list, and the resulting list will be further passed into the wrapped function/class.
This might hang when generators produce an infinite sequence. We might introduce an API to control this behavior in future.
Example:
.. code-block:: python
......@@ -431,6 +436,18 @@ def _argument_processor(arg):
return arg
def _formulate_single_argument(arg):
# this is different from argument processor
# it directly apply the transformation on the stored arguments
# expand generator into list
# Note that some types that are generator (such as range(10)) may not be identified as generator here.
if isinstance(arg, types.GeneratorType):
arg = list(arg)
return arg
def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
# This is to formulate the arguments and make them well-formed.
if kw_only:
......@@ -451,6 +468,9 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
args, kwargs = [], full_args
args = [_formulate_single_argument(arg) for arg in args]
kwargs = {k: _formulate_single_argument(arg) for k, arg in kwargs.items()}
return list(args), kwargs
......
......@@ -221,10 +221,28 @@ def test_lightning_earlystop():
assert any(isinstance(callback, EarlyStopping) for callback in trainer.callbacks)
def test_generator():
import torch.nn as nn
import torch.optim as optim
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 10, 1)
def forward(self, x):
return self.conv(x)
model = Net()
optimizer = nni.trace(optim.Adam)(model.parameters())
print(optimizer.trace_kwargs)
if __name__ == '__main__':
# test_simple_class()
# test_external_class()
# test_nested_class()
# test_unserializable()
# test_basic_unit()
test_multiprocessing_dataloader()
test_generator()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment