"docs/source/vscode:/vscode.git/clone" did not exist on "b22f192c8c9e81eb0e8a5c1108c2a55f9a51ea46"
Unverified Commit cae4308f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Rename APIs and refine documentation (#3404)

parent d047d6f4
...@@ -3,10 +3,8 @@ import sys ...@@ -3,10 +3,8 @@ import sys
import torch import torch
from pathlib import Path from pathlib import Path
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii import serialize
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii import blackbox_module as bm
from base_mnasnet import MNASNet from base_mnasnet import MNASNet
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategy import TPEStrategy from nni.retiarii.strategy import TPEStrategy
...@@ -35,8 +33,8 @@ if __name__ == '__main__': ...@@ -35,8 +33,8 @@ if __name__ == '__main__':
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]) ])
train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform) train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform) test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100), trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2) max_epochs=1, limit_train_batches=0.2)
...@@ -56,4 +54,4 @@ if __name__ == '__main__': ...@@ -56,4 +54,4 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10 exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081) exp.run(exp_config, 8097)
...@@ -2,9 +2,9 @@ import random ...@@ -2,9 +2,9 @@ import random
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy import nni.retiarii.strategy as strategy
import nni.retiarii.trainer.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii import blackbox_module as bm from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
...@@ -36,8 +36,8 @@ class Net(nn.Module): ...@@ -36,8 +36,8 @@ class Net(nn.Module):
if __name__ == '__main__': if __name__ == '__main__':
base_model = Net(128) base_model = Net(128)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = bm(MNIST)(root='data/mnist', train=True, download=True, transform=transform) train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = bm(MNIST)(root='data/mnist', train=False, download=True, transform=transform) test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100), trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2) max_epochs=2)
......
...@@ -340,7 +340,7 @@ ...@@ -340,7 +340,7 @@
} }
] ]
}, },
"_training_config": { "_evaluator": {
"module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer", "module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
"kwargs": { "kwargs": {
"dataset_cls": "MNIST", "dataset_cls": "MNIST",
......
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module from nni.retiarii import basic_unit
@blackbox_module @basic_unit
class ImportTest(nn.Module): class ImportTest(nn.Module):
def __init__(self, foo, bar): def __init__(self, foo, bar):
super().__init__() super().__init__()
......
...@@ -4,7 +4,7 @@ import logging ...@@ -4,7 +4,7 @@ import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.utils import add_record, del_record, version_larger_equal from nni.retiarii.utils import version_larger_equal
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -13,39 +13,23 @@ def wrap_module(original_class): ...@@ -13,39 +13,23 @@ def wrap_module(original_class):
argname_list = list(inspect.signature(original_class).parameters.keys()) argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion # Make copy of original __init__, so we can call it without recursion
original_class.bak_init_for_inject = orig_init original_class.bak_init_for_inject = orig_init
if hasattr(original_class, '__del__'):
orig_del = original_class.__del__
original_class.bak_del_for_inject = orig_del
else:
orig_del = None
original_class.bak_del_for_inject = None
def __init__(self, *args, **kws): def __init__(self, *args, **kws):
full_args = {} full_args = {}
full_args.update(kws) full_args.update(kws)
for i, arg in enumerate(args): for i, arg in enumerate(args):
full_args[argname_list[i]] = arg full_args[argname_list[i]] = arg
add_record(id(self), full_args) self._init_parameters = full_args
orig_init(self, *args, **kws) # Call the original __init__ orig_init(self, *args, **kws) # Call the original __init__
def __del__(self):
del_record(id(self))
if orig_del is not None:
orig_del(self)
original_class.__init__ = __init__ # Set the class' __init__ to the new one original_class.__init__ = __init__ # Set the class' __init__ to the new one
original_class.__del__ = __del__
return original_class return original_class
def unwrap_module(wrapped_class): def unwrap_module(wrapped_class):
if hasattr(wrapped_class, 'bak_init_for_inject'): if hasattr(wrapped_class, 'bak_init_for_inject'):
wrapped_class.__init__ = wrapped_class.bak_init_for_inject wrapped_class.__init__ = wrapped_class.bak_init_for_inject
delattr(wrapped_class, 'bak_init_for_inject') delattr(wrapped_class, 'bak_init_for_inject')
if hasattr(wrapped_class, 'bak_del_for_inject'):
if wrapped_class.bak_del_for_inject is not None:
wrapped_class.__del__ = wrapped_class.bak_del_for_inject
delattr(wrapped_class, 'bak_del_for_inject')
return None return None
def remove_inject_pytorch_nn(): def remove_inject_pytorch_nn():
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
] ]
}, },
"_training_config": { "_evaluator": {
"__type__": "_debug_no_trainer" "__type__": "_debug_no_trainer"
} }
} }
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
] ]
}, },
"_training_config": { "_evaluator": {
"module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer", "module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
"kwargs": { "kwargs": {
"dataset_cls": "MNIST", "dataset_cls": "MNIST",
......
...@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node ...@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node
from nni.retiarii import Model, submit_models from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from nni.retiarii.utils import import_ from nni.retiarii.utils import import_
......
...@@ -12,10 +12,9 @@ import torch.nn.functional as F ...@@ -12,10 +12,9 @@ import torch.nn.functional as F
import torchvision import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
class MnistNet(nn.Module): class MnistNet(nn.Module):
def __init__(self): def __init__(self):
...@@ -35,8 +34,8 @@ class MnistNet(nn.Module): ...@@ -35,8 +34,8 @@ class MnistNet(nn.Module):
x = self.fc2(x) x = self.fc2(x)
return F.log_softmax(x, dim=1) return F.log_softmax(x, dim=1)
# NOTE: blackbox module cannot be placed within class or function # NOTE: serialize module cannot be placed within class or function
@blackbox_module @basic_unit
class Linear(nn.Module): class Linear(nn.Module):
def __init__(self, d_embed, d_proj): def __init__(self, d_embed, d_proj):
super().__init__() super().__init__()
...@@ -66,9 +65,6 @@ class TestConvert(unittest.TestCase): ...@@ -66,9 +65,6 @@ class TestConvert(unittest.TestCase):
model_ir = convert_to_graph(script_module, model) model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn()
exec_vars = {} exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model'] converted_model = exec_vars['converted_model']
...@@ -458,9 +454,12 @@ class TestConvert(unittest.TestCase): ...@@ -458,9 +454,12 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),)) self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),))
def test_torchvision_resnet18(self): def test_torchvision_resnet18(self):
from .inject_nn import inject_pytorch_nn from .inject_nn import inject_pytorch_nn, remove_inject_pytorch_nn
inject_pytorch_nn() try:
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),)) inject_pytorch_nn()
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),))
finally:
remove_inject_pytorch_nn()
def test_resnet(self): def test_resnet(self):
def conv1x1(in_planes, out_planes, stride=1): def conv1x1(in_planes, out_planes, stride=1):
...@@ -572,8 +571,11 @@ class TestConvert(unittest.TestCase): ...@@ -572,8 +571,11 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(resnet18, (torch.randn(1, 3, 224, 224),)) self.checkExportImport(resnet18, (torch.randn(1, 3, 224, 224),))
def test_alexnet(self): def test_alexnet(self):
from .inject_nn import inject_pytorch_nn from .inject_nn import inject_pytorch_nn, remove_inject_pytorch_nn
inject_pytorch_nn() try:
x = torch.ones(1, 3, 224, 224) inject_pytorch_nn()
model = torchvision.models.AlexNet() x = torch.ones(1, 3, 224, 224)
self.checkExportImport(model, (x,)) model = torchvision.models.AlexNet()
self.checkExportImport(model, (x,))
finally:
remove_inject_pytorch_nn()
...@@ -8,10 +8,9 @@ import torch.nn.functional as F ...@@ -8,10 +8,9 @@ import torch.nn.functional as F
import torchvision import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
# following pytorch v1.7.1 # following pytorch v1.7.1
......
...@@ -15,10 +15,8 @@ import torch.nn.functional as F ...@@ -15,10 +15,8 @@ import torch.nn.functional as F
import torchvision import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
# following pytorch v1.7.1 # following pytorch v1.7.1
......
...@@ -14,10 +14,9 @@ import torch.nn.functional as F ...@@ -14,10 +14,9 @@ import torch.nn.functional as F
import torchvision import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
class TestPytorch(unittest.TestCase): class TestPytorch(unittest.TestCase):
......
...@@ -17,7 +17,6 @@ from nni.retiarii import Model, Node ...@@ -17,7 +17,6 @@ from nni.retiarii import Model, Node
from nni.retiarii import Model, submit_models from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from nni.retiarii.utils import import_ from nni.retiarii.utils import import_
...@@ -74,7 +73,7 @@ class DedupInputTest(unittest.TestCase): ...@@ -74,7 +73,7 @@ class DedupInputTest(unittest.TestCase):
# sys.path.insert(0, 'generated') # sys.path.insert(0, 'generated')
# multi_model = import_('debug_dedup_input.logical_0') # multi_model = import_('debug_dedup_input.logical_0')
# trainer = PyTorchMultiModelTrainer( # trainer = PyTorchMultiModelTrainer(
# multi_model(), phy_models[0][0].training_config.kwargs # multi_model(), phy_models[0][0].evaluator.kwargs
# ) # )
# trainer.fit() # trainer.fit()
......
...@@ -9,7 +9,7 @@ import nni ...@@ -9,7 +9,7 @@ import nni
from nni.retiarii import Model, submit_models from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor, register_advisor from nni.retiarii.integration import RetiariiAdvisor, register_advisor
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer
from nni.retiarii.utils import import_ from nni.retiarii.utils import import_
......
...@@ -23,7 +23,7 @@ def _test_file(json_path): ...@@ -23,7 +23,7 @@ def _test_file(json_path):
# add default values to JSON, so we can compare with `==` # add default values to JSON, so we can compare with `==`
for graph_name, graph in orig_ir.items(): for graph_name, graph in orig_ir.items():
if graph_name == '_training_config': if graph_name == '_evaluator':
continue continue
if 'inputs' not in graph: if 'inputs' not in graph:
graph['inputs'] = None graph['inputs'] = None
......
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii import Sampler, blackbox_module from nni.retiarii import Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
...@@ -29,7 +29,7 @@ class RandomSampler(Sampler): ...@@ -29,7 +29,7 @@ class RandomSampler(Sampler):
return random.choice(candidates) return random.choice(candidates)
@blackbox_module @basic_unit
class MutableConv(nn.Module): class MutableConv(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -2,13 +2,13 @@ import json ...@@ -2,13 +2,13 @@ import json
import pytest import pytest
import nni import nni
import nni.retiarii.trainer.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import pytorch_lightning import pytorch_lightning
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii import blackbox_module as bm from nni.retiarii import serialize_cls, serialize
from nni.retiarii.trainer import FunctionalTrainer from nni.retiarii.evaluator import FunctionalEvaluator
from sklearn.datasets import load_diabetes from sklearn.datasets import load_diabetes
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision import transforms from torchvision import transforms
...@@ -49,7 +49,7 @@ class FCNet(nn.Module): ...@@ -49,7 +49,7 @@ class FCNet(nn.Module):
return output.view(-1) return output.view(-1)
@bm @serialize_cls
class DiabetesDataset(Dataset): class DiabetesDataset(Dataset):
def __init__(self, train=True): def __init__(self, train=True):
data = load_diabetes() data = load_diabetes()
...@@ -91,8 +91,8 @@ def _reset(): ...@@ -91,8 +91,8 @@ def _reset():
def test_mnist(): def test_mnist():
_reset() _reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = bm(MNIST)(root='data/mnist', train=True, download=True, transform=transform) train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = bm(MNIST)(root='data/mnist', train=False, download=True, transform=transform) test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100), lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2, limit_train_batches=0.25, # for faster training max_epochs=2, limit_train_batches=0.25, # for faster training
...@@ -121,7 +121,7 @@ def test_diabetes(): ...@@ -121,7 +121,7 @@ def test_diabetes():
@pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.') @pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.')
def test_functional(): def test_functional():
FunctionalTrainer(_foo)._execute(MNISTModel) FunctionalEvaluator(_foo)._execute(MNISTModel)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -4,7 +4,7 @@ import re ...@@ -4,7 +4,7 @@ import re
import sys import sys
import torch import torch
from nni.retiarii import json_dumps, json_loads, blackbox from nni.retiarii import json_dumps, json_loads, serialize
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
...@@ -23,30 +23,30 @@ class Foo: ...@@ -23,30 +23,30 @@ class Foo:
return self.aa == other.aa and self.bb == other.bb return self.aa == other.aa and self.bb == other.bb
def test_blackbox(): def test_serialize():
module = blackbox(Foo, 3) module = serialize(Foo, 3)
assert json_loads(json_dumps(module)) == module assert json_loads(json_dumps(module)) == module
module = blackbox(Foo, b=2, a=1) module = serialize(Foo, b=2, a=1)
assert json_loads(json_dumps(module)) == module assert json_loads(json_dumps(module)) == module
module = blackbox(Foo, Foo(1), 5) module = serialize(Foo, Foo(1), 5)
dumped_module = json_dumps(module) dumped_module = json_dumps(module)
assert len(dumped_module) > 200 # should not be too longer if the serialization is correct assert len(dumped_module) > 200 # should not be too longer if the serialization is correct
module = blackbox(Foo, blackbox(Foo, 1), 5) module = serialize(Foo, serialize(Foo, 1), 5)
dumped_module = json_dumps(module) dumped_module = json_dumps(module)
assert len(dumped_module) < 200 # should not be too longer if the serialization is correct assert len(dumped_module) < 200 # should not be too longer if the serialization is correct
assert json_loads(dumped_module) == module assert json_loads(dumped_module) == module
def test_blackbox_module(): def test_basic_unit():
module = ImportTest(3, 0.5) module = ImportTest(3, 0.5)
assert json_loads(json_dumps(module)) == module assert json_loads(json_dumps(module)) == module
def test_dataset(): def test_dataset():
dataset = blackbox(MNIST, root='data/mnist', train=False, download=True) dataset = serialize(MNIST, root='data/mnist', train=False, download=True)
dataloader = blackbox(DataLoader, dataset, batch_size=10) dataloader = serialize(DataLoader, dataset, batch_size=10)
dumped_ans = { dumped_ans = {
"__type__": "torch.utils.data.dataloader.DataLoader", "__type__": "torch.utils.data.dataloader.DataLoader",
...@@ -62,19 +62,19 @@ def test_dataset(): ...@@ -62,19 +62,19 @@ def test_dataset():
dataloader = json_loads(json_dumps(dumped_ans)) dataloader = json_loads(json_dumps(dumped_ans))
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, DataLoader)
dataset = blackbox(MNIST, root='data/mnist', train=False, download=True, dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
transform=blackbox( transform=serialize(
transforms.Compose, transforms.Compose,
[blackbox(transforms.ToTensor), blackbox(transforms.Normalize, (0.1307,), (0.3081,))] [serialize(transforms.ToTensor), serialize(transforms.Normalize, (0.1307,), (0.3081,))]
)) ))
dataloader = blackbox(DataLoader, dataset, batch_size=10) dataloader = serialize(DataLoader, dataset, batch_size=10)
x, y = next(iter(json_loads(json_dumps(dataloader)))) x, y = next(iter(json_loads(json_dumps(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28]) assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10]) assert y.size() == torch.Size([10])
dataset = blackbox(MNIST, root='data/mnist', train=False, download=True, dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
dataloader = blackbox(DataLoader, dataset, batch_size=10) dataloader = serialize(DataLoader, dataset, batch_size=10)
x, y = next(iter(json_loads(json_dumps(dataloader)))) x, y = next(iter(json_loads(json_dumps(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28]) assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10]) assert y.size() == torch.Size([10])
...@@ -87,7 +87,7 @@ def test_type(): ...@@ -87,7 +87,7 @@ def test_type():
if __name__ == '__main__': if __name__ == '__main__':
test_blackbox() test_serialize()
test_blackbox_module() test_basic_unit()
test_dataset() test_dataset()
test_type() test_type()
...@@ -12,7 +12,7 @@ from nni.retiarii import Model ...@@ -12,7 +12,7 @@ from nni.retiarii import Model
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.execution import wait_models from nni.retiarii.execution import wait_models
from nni.retiarii.execution.interface import AbstractExecutionEngine, WorkerInfo, MetricData, AbstractGraphListener from nni.retiarii.execution.interface import AbstractExecutionEngine, WorkerInfo, MetricData, AbstractGraphListener
from nni.retiarii.graph import DebugTraining, ModelStatus from nni.retiarii.graph import DebugEvaluator, ModelStatus
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
...@@ -80,7 +80,7 @@ def _get_model_and_mutators(): ...@@ -80,7 +80,7 @@ def _get_model_and_mutators():
base_model = Net() base_model = Net()
script_module = torch.jit.script(base_model) script_module = torch.jit.script(base_model)
base_model_ir = convert_to_graph(script_module, base_model) base_model_ir = convert_to_graph(script_module, base_model)
base_model_ir.training_config = DebugTraining() base_model_ir.evaluator = DebugEvaluator()
mutators = process_inline_mutation(base_model_ir) mutators = process_inline_mutation(base_model_ir)
return base_model_ir, mutators return base_model_ir, mutators
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment