import math import sys from pathlib import Path import pytest import nni import torch from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST from nni.common.serializer import is_traceable if True: # prevent auto formatting sys.path.insert(0, Path(__file__).parent.as_posix()) from imported.model import ImportTest # this test cannot be directly put in this file. It will cause syntax error for python <= 3.7. if tuple(sys.version_info) >= (3, 8): from imported._test_serializer_py38 import test_positional_only @nni.trace class SimpleClass: def __init__(self, a, b=1): self._a = a self._b = b class UnserializableSimpleClass: def __init__(self): self._a = 1 def test_simple_class(): instance = SimpleClass(1, 2) assert instance._a == 1 assert instance._b == 2 dump_str = nni.dump(instance) assert '"__kwargs__": {"a": 1, "b": 2}' in dump_str assert '"__symbol__"' in dump_str instance = nni.load(dump_str) assert instance._a == 1 assert instance._b == 2 def test_external_class(): from collections import OrderedDict d = nni.trace(kw_only=False)(OrderedDict)([('a', 1), ('b', 2)]) assert d['a'] == 1 assert d['b'] == 2 dump_str = nni.dump(d) assert dump_str == '{"a": 1, "b": 2}' conv = nni.trace(torch.nn.Conv2d)(3, 16, 3) assert conv.in_channels == 3 assert conv.out_channels == 16 assert conv.kernel_size == (3, 3) assert nni.dump(conv) == \ r'{"__symbol__": "path:torch.nn.modules.conv.Conv2d", ' \ r'"__kwargs__": {"in_channels": 3, "out_channels": 16, "kernel_size": 3}}' conv = nni.load(nni.dump(conv)) assert conv.kernel_size == (3, 3) def test_nested_class(): a = SimpleClass(1, 2) b = SimpleClass(a) assert b._a._a == 1 dump_str = nni.dump(b) b = nni.load(dump_str) assert 'SimpleClass object at' in repr(b) assert b._a._a == 1 def test_unserializable(): a = UnserializableSimpleClass() dump_str = nni.dump(a) a = nni.load(dump_str) assert a._a == 1 def test_function(): t = nni.trace(math.sqrt, kw_only=False)(3) assert 1 < t < 2 assert t.trace_symbol == math.sqrt assert t.trace_args == [3] t = nni.load(nni.dump(t)) assert 1 < t < 2 assert not is_traceable(t) # trace not recovered, expected, limitation def simple_class_factory(bb=3.): return SimpleClass(1, bb) t = nni.trace(simple_class_factory)(4) ts = nni.dump(t) assert '__kwargs__' in ts t = nni.load(ts) assert t._a == 1 assert is_traceable(t) t = t.trace_copy() assert is_traceable(t) assert t.trace_symbol(10)._b == 10 assert t.trace_kwargs['bb'] == 4 assert is_traceable(t.trace_copy()) class Foo: def __init__(self, a, b=1): self.aa = a self.bb = [b + 1 for _ in range(1000)] def __eq__(self, other): return self.aa == other.aa and self.bb == other.bb def test_custom_class(): module = nni.trace(Foo)(3) assert nni.load(nni.dump(module)) == module module = nni.trace(Foo)(b=2, a=1) assert nni.load(nni.dump(module)) == module module = nni.trace(Foo)(Foo(1), 5) dumped_module = nni.dump(module) assert len(dumped_module) > 200 # should not be too longer if the serialization is correct module = nni.trace(Foo)(nni.trace(Foo)(1), 5) dumped_module = nni.dump(module) assert nni.load(dumped_module) == module class Foo: def __init__(self, a, b=1): self.aa = a self.bb = [b + 1 for _ in range(1000)] def __eq__(self, other): return self.aa == other.aa and self.bb == other.bb def test_basic_unit_and_custom_import(): module = ImportTest(3, 0.5) ss = nni.dump(module) assert ss == r'{"__symbol__": "path:imported.model.ImportTest", "__kwargs__": {"foo": 3, "bar": 0.5}}' assert nni.load(nni.dump(module)) == module import nni.retiarii.nn.pytorch as nn module = nn.Conv2d(3, 10, 3, bias=False) ss = nni.dump(module) assert ss == r'{"__symbol__": "path:torch.nn.modules.conv.Conv2d", "__kwargs__": {"in_channels": 3, "out_channels": 10, "kernel_size": 3, "bias": false}}' assert nni.load(ss).bias is None def test_dataset(): dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True) dataloader = nni.trace(DataLoader)(dataset, batch_size=10) dumped_ans = { "__symbol__": "path:torch.utils.data.dataloader.DataLoader", "__kwargs__": { "dataset": { "__symbol__": "path:torchvision.datasets.mnist.MNIST", "__kwargs__": {"root": "data/mnist", "train": False, "download": True} }, "batch_size": 10 } } print(nni.dump(dataloader)) print(nni.dump(dumped_ans)) assert nni.dump(dataloader) == nni.dump(dumped_ans) dataloader = nni.load(nni.dump(dumped_ans)) assert isinstance(dataloader, DataLoader) dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=nni.trace(transforms.Compose)([ nni.trace(transforms.ToTensor)(), nni.trace(transforms.Normalize)((0.1307,), (0.3081,)) ])) dataloader = nni.trace(DataLoader)(dataset, batch_size=10) x, y = next(iter(nni.load(nni.dump(dataloader)))) assert x.size() == torch.Size([10, 1, 28, 28]) assert y.size() == torch.Size([10]) dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=nni.trace(transforms.Compose)( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] )) dataloader = nni.trace(DataLoader)(dataset, batch_size=10) x, y = next(iter(nni.load(nni.dump(dataloader)))) assert x.size() == torch.Size([10, 1, 28, 28]) assert y.size() == torch.Size([10]) @pytest.mark.skipif(sys.platform != 'linux', reason='https://github.com/microsoft/nni/issues/4434') def test_multiprocessing_dataloader(): # check whether multi-processing works # it's possible to have pickle errors dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=nni.trace(transforms.Compose)( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] )) import nni.retiarii.evaluator.pytorch.lightning as pl dataloader = pl.DataLoader(dataset, batch_size=10, num_workers=2) x, y = next(iter(dataloader)) assert x.size() == torch.Size([10, 1, 28, 28]) assert y.size() == torch.Size([10]) def test_type(): assert nni.dump(torch.optim.Adam) == '{"__nni_type__": "path:torch.optim.adam.Adam"}' assert nni.load('{"__nni_type__": "path:torch.optim.adam.Adam"}') == torch.optim.Adam assert Foo == nni.load(nni.dump(Foo)) assert nni.dump(math.floor) == '{"__nni_type__": "path:math.floor"}' assert nni.load('{"__nni_type__": "path:math.floor"}') == math.floor def test_lightning_earlystop(): import nni.retiarii.evaluator.pytorch.lightning as pl from pytorch_lightning.callbacks.early_stopping import EarlyStopping trainer = pl.Trainer(callbacks=[nni.trace(EarlyStopping)(monitor="val_loss")]) trainer = nni.load(nni.dump(trainer)) assert any(isinstance(callback, EarlyStopping) for callback in trainer.callbacks) def test_generator(): import torch.nn as nn import torch.optim as optim class Net(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 10, 1) def forward(self, x): return self.conv(x) model = Net() optimizer = nni.trace(optim.Adam)(model.parameters()) print(optimizer.trace_kwargs) def test_arguments_kind(): def foo(a, b, *c, **d): pass d = nni.trace(foo)(1, 2, 3, 4) assert d.trace_args == [1, 2, 3, 4] assert d.trace_kwargs == {} d = nni.trace(foo)(a=1, b=2) assert d.trace_kwargs == dict(a=1, b=2) d = nni.trace(foo)(1, b=2) # this is not perfect, but it's safe assert d.trace_kwargs == dict(a=1, b=2) def foo(a, *, b=3, c=5): pass d = nni.trace(foo)(1, b=2, c=3) assert d.trace_kwargs == dict(a=1, b=2, c=3) import torch.nn as nn lstm = nni.trace(nn.LSTM)(2, 2) assert lstm.input_size == 2 assert lstm.hidden_size == 2 assert lstm.trace_args == [2, 2] lstm = nni.trace(nn.LSTM)(input_size=2, hidden_size=2) assert lstm.trace_kwargs == {'input_size': 2, 'hidden_size': 2} if __name__ == '__main__': # test_simple_class() # test_external_class() # test_nested_class() # test_unserializable() # test_basic_unit() # test_generator() test_arguments_kind()