"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "5384c7d7716860214f0d4bb3906da66fd00151a7"
Unverified Commit 21abc280 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix #4434: support pickle in serializer (#4552)

parent c447249c
......@@ -10,7 +10,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from nni.common.serializer import _trace_cls
from nni.common.serializer import Traceable
from nni.common.serializer import Traceable, is_traceable
__all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']
......@@ -80,14 +80,14 @@ class OptimizerConstructHelper(ConstructHelper):
@staticmethod
def from_trace(model: Module, optimizer_trace: Traceable):
assert isinstance(optimizer_trace, Traceable), \
assert is_traceable(optimizer_trace), \
'Please use nni.trace to wrap the optimizer class before initialize the optimizer.'
assert isinstance(optimizer_trace, Optimizer), \
'It is not an instance of torch.nn.Optimizer.'
return OptimizerConstructHelper(model,
optimizer_trace._get_nni_attr('symbol'),
*optimizer_trace._get_nni_attr('args'),
**optimizer_trace._get_nni_attr('kwargs'))
optimizer_trace.trace_symbol,
*optimizer_trace.trace_args,
**optimizer_trace.trace_kwargs)
class LRSchedulerConstructHelper(ConstructHelper):
......@@ -112,7 +112,7 @@ class LRSchedulerConstructHelper(ConstructHelper):
@staticmethod
def from_trace(lr_scheduler_trace: Traceable):
assert isinstance(lr_scheduler_trace, Traceable), \
assert is_traceable(lr_scheduler_trace), \
'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
assert isinstance(lr_scheduler_trace, _LRScheduler), \
'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
......
......@@ -5,6 +5,7 @@ import copy
import functools
import inspect
import numbers
import sys
import types
import warnings
from io import IOBase
......@@ -13,7 +14,7 @@ from typing import Any, Dict, List, Optional, TypeVar, Union
import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend
__all__ = ['trace', 'dump', 'load', 'PayloadTooLarge', 'Translatable', 'Traceable', 'is_traceable']
__all__ = ['trace', 'dump', 'load', 'PayloadTooLarge', 'Translatable', 'Traceable', 'is_traceable', 'is_wrapped_with_trace']
T = TypeVar('T')
......@@ -23,46 +24,43 @@ class PayloadTooLarge(Exception):
pass
class Traceable(abc.ABC):
class Traceable:
"""
A traceable object have copy and dict. Copy and mutate are used to copy the object for further mutations.
Dict returns a TraceDictType to enable serialization.
"""
@abc.abstractmethod
def trace_copy(self) -> 'Traceable':
"""
Perform a shallow copy.
NOTE: NONE of the attributes will be preserved.
This is the one that should be used when you want to "mutate" a serializable object.
"""
...
raise NotImplementedError()
@property
@abc.abstractmethod
def trace_symbol(self) -> Any:
"""
Symbol object. Could be a class or a function.
``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` is a pair to
convert the symbol into a string and convert the string back to symbol.
"""
...
raise NotImplementedError()
@property
@abc.abstractmethod
def trace_args(self) -> List[Any]:
"""
List of positional arguments passed to symbol. Usually empty if ``kw_only`` is true,
in which case all the positional arguments are converted into keyword arguments.
"""
...
raise NotImplementedError()
@property
@abc.abstractmethod
def trace_kwargs(self) -> Dict[str, Any]:
"""
Dict of keyword arguments.
"""
...
raise NotImplementedError()
class Translatable(abc.ABC):
......@@ -84,13 +82,27 @@ class Translatable(abc.ABC):
def is_traceable(obj: Any) -> bool:
"""
Check whether an object is a traceable instance (not type).
Check whether an object is a traceable instance or type.
Note that an object is traceable only means that it implements the "Traceable" interface,
and the properties have been implemented. It doesn't necessary mean that its type is wrapped with trace,
because the properties could be added **after** the instance has been created.
"""
return hasattr(obj, 'trace_copy') and \
hasattr(obj, 'trace_symbol') and \
hasattr(obj, 'trace_args') and \
hasattr(obj, 'trace_kwargs') and \
not inspect.isclass(obj)
hasattr(obj, 'trace_kwargs')
def is_wrapped_with_trace(cls_or_func: Any) -> bool:
"""
Check whether a function or class is already wrapped with ``@nni.trace``.
If a class or function is already wrapped with trace, then the created object must be "traceable".
"""
return getattr(cls_or_func, '_traced', False) and (
not hasattr(cls_or_func, '__dict__') or # in case it's a function
'_traced' in cls_or_func.__dict__ # must be in this class, super-class traced doesn't count
)
class SerializableObject(Traceable):
......@@ -160,6 +172,15 @@ class SerializableObject(Traceable):
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# Make obj complying with the interface of traceable, though we cannot change its base class.
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
return obj
def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
# Make an already exist class traceable, without creating a new class.
# Should be used together with `inject_trace_info`.
def getter_factory(x):
return lambda self: self.__dict__['_nni_' + x]
......@@ -184,20 +205,18 @@ def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, An
'trace_copy': trace_copy
}
if hasattr(obj, '__class__') and hasattr(obj, '__dict__'):
if not create_wrapper:
for name, method in attributes.items():
setattr(obj.__class__, name, method)
setattr(cls, name, method)
return cls
else:
wrapper = type('wrapper', (Traceable, type(obj)), attributes)
obj = wrapper(obj) # pylint: disable=abstract-class-instantiated
# make obj complying with the interface of traceable, though we cannot change its base class
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
return obj
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
# but I don't want to check here because it's unreliable.
wrapper = type('wrapper', (Traceable, cls), attributes)
return wrapper
def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]:
def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = False) -> Union[T, Traceable]:
"""
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
......@@ -221,6 +240,9 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Therefore, in some cases, some positional arguments will still be kept.
If ``inheritable`` is true, the trace information from superclass will also be available in subclass.
This however, will make the subclass un-trace-able. Note that this argument has no effect when tracing functions.
.. warning::
Generators will be first expanded into a list, and the resulting list will be further passed into the wrapped function/class.
......@@ -237,10 +259,10 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
def wrap(cls_or_func):
# already annotated, do nothing
if getattr(cls_or_func, '_traced', False):
if is_wrapped_with_trace(cls_or_func):
return cls_or_func
if isinstance(cls_or_func, type):
cls_or_func = _trace_cls(cls_or_func, kw_only)
cls_or_func = _trace_cls(cls_or_func, kw_only, inheritable=inheritable)
elif _is_function(cls_or_func):
cls_or_func = _trace_func(cls_or_func, kw_only)
else:
......@@ -353,11 +375,60 @@ def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comme
return json_tricks.load(fp, obj_pairs_hooks=hooks, **json_tricks_kwargs)
def _trace_cls(base, kw_only, call_super=True):
def _trace_cls(base, kw_only, call_super=True, inheritable=False):
# the implementation to trace a class is to store a copy of init arguments
# this won't support class that defines a customized new but should work for most cases
class wrapper(SerializableObject, base):
if sys.platform != 'linux':
if not call_super:
raise ValueError("'call_super' is mandatory to be set true on non-linux platform")
try:
# In non-linux envs, dynamically creating new classes doesn't work with pickle.
# We have to replace the ``__init__`` with a new ``__init__``.
# This, however, causes side-effects where the replacement is not intended.
# This also doesn't work built-in types (e.g., OrderedDict), and the replacement
# won't be effective any more if ``nni.trace`` is called in-place (e.g., ``nni.trace(nn.Conv2d)(...)``).
original_init = base.__init__
# Makes the new init have the exact same signature as the old one,
# so as to make pytorch-lightning happy.
# https://github.com/PyTorchLightning/pytorch-lightning/blob/4cc05b2cf98e49168a5f5dc265647d75d1d3aae9/pytorch_lightning/utilities/parsing.py#L143
@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
args, kwargs = _formulate_arguments(original_init, args, kwargs, kw_only, is_class_init=True)
original_init(
self,
*[_argument_processor(arg) for arg in args],
**{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
)
inject_trace_info(self, base, args, kwargs)
base.__init__ = new_init
base = _make_class_traceable(base)
return base
except TypeError:
warnings.warn("In-place __init__ replacement failed in `@nni.trace`, probably because the type is a built-in/extension type, "
"and it's __init__ can't be replaced. `@nni.trace` is now falling back to the 'inheritance' approach. "
"However, this could cause issues when using pickle. See https://github.com/microsoft/nni/issues/4434",
RuntimeWarning)
# This is trying to solve the case where superclass and subclass are both decorated with @nni.trace.
# We use a metaclass to "unwrap" the superclass.
# However, this doesn't work if:
# 1. Base class already has a customized metaclass. We will raise error in that class.
# 2. SerializableObject in ancester (instead of parent). I think this case is rare and I didn't handle this case yet. FIXME
if type(base) is type and not inheritable:
metaclass = _unwrap_metaclass
else:
metaclass = type
if SerializableObject in inspect.getmro(base):
raise TypeError(f"{base} has a superclass already decorated with trace, and it's using a customized metaclass {type(base)}. "
"Please either use the default metaclass, or remove trace from the super-class.")
class wrapper(SerializableObject, base, metaclass=metaclass):
def __init__(self, *args, **kwargs):
# store a copy of initial parameters
args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
......@@ -365,6 +436,32 @@ def _trace_cls(base, kw_only, call_super=True):
# calling serializable object init to initialize the full object
super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=call_super)
def __reduce__(self):
# The issue that decorator and pickler doesn't play well together is well known.
# The workaround solution is to use a fool class (_pickling_object) which pretends to be the pickled object.
# We then put the original type, as well as args and kwargs in its `__new__` argument.
# I suspect that their could still be problems when things get complex,
# e.g., the wrapped class has a custom pickling (`__reduce__``) or `__new__`.
# But it can't be worse because the previous pickle doesn't work at all.
#
# Linked issue: https://github.com/microsoft/nni/issues/4434
# SO: https://stackoverflow.com/questions/52185507/pickle-and-decorated-classes-picklingerror-not-the-same-object
# Store the inner class. The wrapped class couldn't be properly pickled.
type_ = cloudpickle.dumps(type(self).__wrapped__)
# in case they have customized ``__getstate__``.
if hasattr(self, '__getstate__'):
obj_ = self.__getstate__()
else:
obj_ = self.__dict__
# Pickle can't handle type objects.
if '_nni_symbol' in obj_:
obj_['_nni_symbol'] = cloudpickle.dumps(obj_['_nni_symbol'])
return _pickling_object, (type_, kw_only, obj_)
_copy_class_wrapper_attributes(base, wrapper)
return wrapper
......@@ -391,6 +488,8 @@ def _trace_func(func, kw_only):
elif hasattr(res, '__class__') and hasattr(res, '__dict__'):
# is a class, inject interface directly
# need to be done before primitive types because there could be inheritance here.
if not getattr(type(res), '_traced', False):
_make_class_traceable(type(res), False) # in-place
res = inject_trace_info(res, func, args, kwargs)
elif isinstance(res, (collections.abc.Callable, types.ModuleType, IOBase)):
raise TypeError(f'Try to add trace info to {res}, but functions and modules are not supported.')
......@@ -400,6 +499,8 @@ def _trace_func(func, kw_only):
# will be directly captured by python json encoder
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
new_type = _make_class_traceable(type(res), True)
res = new_type(res) # re-creating the object
res = inject_trace_info(res, func, args, kwargs)
else:
raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
......@@ -425,6 +526,48 @@ def _copy_class_wrapper_attributes(base, wrapper):
wrapper.__wrapped__ = base
class _unwrap_metaclass(type):
# When a subclass is created, it detects whether the super-class is already annotated with @nni.trace.
# If yes, it gets the ``__wrapped__`` inner class, so that it doesn't inherit SerializableObject twice.
# Note that this doesn't work when metaclass is already defined (such as ABCMeta). We give up in that case.
def __new__(cls, name, bases, dct):
bases = tuple([getattr(base, '__wrapped__', base) for base in bases])
return super().__new__(cls, name, bases, dct)
# Using a customized "bases" breaks default isinstance and issubclass.
# We recover this by overriding the subclass and isinstance behavior, which conerns wrapped class only.
def __subclasscheck__(cls, subclass):
inner_cls = getattr(cls, '__wrapped__', cls)
return inner_cls in inspect.getmro(subclass)
def __instancecheck__(cls, instance):
inner_cls = getattr(cls, '__wrapped__', cls)
return inner_cls in inspect.getmro(type(instance))
class _pickling_object:
# Need `cloudpickle.load` on the callable because the callable is pickled with cloudpickle.
# Used in `_trace_cls`.
def __new__(cls, type_, kw_only, data):
type_ = cloudpickle.loads(type_)
# Restore the trace type
type_ = _trace_cls(type_, kw_only)
# restore type
if '_nni_symbol' in data:
data['_nni_symbol'] = cloudpickle.loads(data['_nni_symbol'])
# https://docs.python.org/3/library/pickle.html#pickling-class-instances
obj = type_.__new__(type_)
if hasattr(obj, '__setstate__'):
obj.__setstate__(data)
else:
obj.__dict__.update(data)
return obj
def _argument_processor(arg):
# 1) translate
# handle cases like ValueChoice
......@@ -533,7 +676,9 @@ def _import_cls_or_func_from_name(target: str) -> Any:
def _strip_trace_type(traceable: Any) -> Any:
if getattr(traceable, '_traced', False):
return traceable.__wrapped__
# sometimes, ``__wrapped__`` could be unavailable (e.g., with `inject_trace_info`)
# need to have a default value
return getattr(traceable, '__wrapped__', traceable)
return traceable
......@@ -598,7 +743,7 @@ def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False,
# Encodes a serializable object instance to json.
# do nothing to instance that is not a serializable object and do not use trace
if not use_trace or not is_traceable(obj):
if not (use_trace and hasattr(obj, '__class__') and is_traceable(type(obj))):
return obj
if isinstance(obj.trace_symbol, property):
......
......@@ -101,6 +101,7 @@ class _MultiModelSupervisedLearningModule(LightningModule):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
@nni.trace
class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
"""
Lightning Module of SupervisedLearning for Cross-Graph Optimization.
......
......@@ -5,7 +5,7 @@ import inspect
import warnings
from typing import Any, TypeVar, Union
from nni.common.serializer import Traceable, is_traceable, trace, _copy_class_wrapper_attributes
from nni.common.serializer import Traceable, is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
......@@ -64,7 +64,8 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
class PrimitiveOp(nn.Module):
...
"""
_check_wrapped(cls)
if _check_wrapped(cls, 'basic_unit'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
......@@ -72,15 +73,7 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
_torchscript_patch(cls)
return cls
......@@ -103,12 +96,14 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
But in future, we might enforce ``@model_wrapper`` to be required for base model.
"""
_check_wrapped(cls)
if _check_wrapped(cls, 'model_wrapper'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module)
wrapper = trace(cls)
# subclass can still use trace info
wrapper = trace(cls, inheritable=True)
class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
......@@ -116,8 +111,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
super().__init__(*args, **kwargs)
_copy_class_wrapper_attributes(wrapper, reset_wrapper)
reset_wrapper.__wrapped__ = wrapper.__wrapped__
reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
reset_wrapper._nni_model_wrapper = True
reset_wrapper._traced = True
_torchscript_patch(cls)
return reset_wrapper
......@@ -133,6 +132,32 @@ def is_model_wrapped(cls_or_instance) -> bool:
return getattr(cls_or_instance, '_nni_model_wrapper', False)
def _check_wrapped(cls: T) -> bool:
if getattr(cls, '_traced', False) or getattr(cls, '_nni_model_wrapper', False):
raise TypeError(f'{cls} is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.')
def _check_wrapped(cls: T, rewrap: str) -> bool:
wrapped = None
if is_model_wrapped(cls):
wrapped = 'model_wrapper'
elif is_basic_unit(cls):
wrapped = 'basic_unit'
elif is_wrapped_with_trace(cls):
wrapped = 'nni.trace'
if wrapped:
if wrapped != rewrap:
raise TypeError(f'{cls} is already wrapped with {wrapped}. Cannot rewrap with {rewrap}.')
return True
return False
def _torchscript_patch(cls) -> None:
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
if hasattr(cls, '_get_nni_attr'): # could not exist on non-linux
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if hasattr(cls, 'trace_symbol'):
# these must all exist or all non-exist
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
......@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from pathlib import Path
import nni
import nni.runtime.platform.test
try:
from nni.common.device import GPUDevice
......
import math
import pickle
import sys
from pathlib import Path
......@@ -27,6 +28,11 @@ class SimpleClass:
self._b = b
@nni.trace
class EmptyClass:
pass
class UnserializableSimpleClass:
def __init__(self):
self._a = 1
......@@ -124,7 +130,8 @@ def test_custom_class():
module = nni.trace(Foo)(Foo(1), 5)
dumped_module = nni.dump(module)
assert len(dumped_module) > 200 # should not be too longer if the serialization is correct
module = nni.load(dumped_module)
assert module.bb[0] == module.bb[999] == 6
module = nni.trace(Foo)(nni.trace(Foo)(1), 5)
dumped_module = nni.dump(module)
......@@ -193,6 +200,20 @@ def test_dataset():
assert y.size() == torch.Size([10])
def test_pickle():
pickle.dumps(EmptyClass())
obj = SimpleClass(1)
obj = pickle.loads(pickle.dumps(obj))
assert obj._a == 1
assert obj._b == 1
obj = SimpleClass(1)
obj.xxx = 3
obj = pickle.loads(pickle.dumps(obj))
assert obj.xxx == 3
@pytest.mark.skipif(sys.platform != 'linux', reason='https://github.com/microsoft/nni/issues/4434')
def test_multiprocessing_dataloader():
# check whether multi-processing works
......@@ -208,6 +229,28 @@ def test_multiprocessing_dataloader():
assert y.size() == torch.Size([10])
def _test_multiprocessing_dataset_worker(dataset):
if sys.platform == 'linux':
# on non-linux, the loaded object will become non-traceable
# due to an implementation limitation
assert is_traceable(dataset)
else:
from torch.utils.data import Dataset
assert isinstance(dataset, Dataset)
def test_multiprocessing_dataset():
from torch.utils.data import Dataset
dataset = nni.trace(Dataset)()
import multiprocessing
process = multiprocessing.Process(target=_test_multiprocessing_dataset_worker, args=(dataset, ))
process.start()
process.join()
assert process.exitcode == 0
def test_type():
assert nni.dump(torch.optim.Adam) == '{"__nni_type__": "path:torch.optim.adam.Adam"}'
assert nni.load('{"__nni_type__": "path:torch.optim.adam.Adam"}') == torch.optim.Adam
......@@ -220,10 +263,20 @@ def test_lightning_earlystop():
import nni.retiarii.evaluator.pytorch.lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
trainer = pl.Trainer(callbacks=[nni.trace(EarlyStopping)(monitor="val_loss")])
trainer = nni.load(nni.dump(trainer))
pickle_size_limit = 4096 if sys.platform == 'linux' else 32768
trainer = nni.load(nni.dump(trainer, pickle_size_limit=pickle_size_limit))
assert any(isinstance(callback, EarlyStopping) for callback in trainer.callbacks)
def test_pickle_trainer():
import nni.retiarii.evaluator.pytorch.lightning as pl
from pytorch_lightning import Trainer
trainer = pl.Trainer(max_epochs=1)
data = pickle.dumps(trainer)
trainer = pickle.loads(data)
assert isinstance(trainer, Trainer)
def test_generator():
import torch.nn as nn
import torch.optim as optim
......@@ -272,11 +325,31 @@ def test_arguments_kind():
assert lstm.trace_kwargs == {'input_size': 2, 'hidden_size': 2}
if __name__ == '__main__':
# test_simple_class()
# test_external_class()
# test_nested_class()
# test_unserializable()
# test_basic_unit()
# test_generator()
test_arguments_kind()
def test_subclass():
@nni.trace
class Super:
def __init__(self, a, b):
self._a = a
self._b = b
class Sub1(Super):
def __init__(self, c, d):
super().__init__(3, 4)
self._c = c
self._d = d
@nni.trace
class Sub2(Super):
def __init__(self, c, d):
super().__init__(3, 4)
self._c = c
self._d = d
obj = Sub1(1, 2)
# There could be trace_kwargs for obj. Behavior is undefined.
assert obj._a == 3 and obj._c == 1
assert isinstance(obj, Super)
obj = Sub2(1, 2)
assert obj.trace_kwargs == {'c': 1, 'd': 2}
assert issubclass(type(obj), Super)
assert isinstance(obj, Super)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment