Unverified Commit cae4308f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Rename APIs and refine documentation (#3404)

parent d047d6f4
......@@ -3,10 +3,8 @@ import sys
import torch
from pathlib import Path
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii import blackbox_module as bm
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii import serialize
from base_mnasnet import MNASNet
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategy import TPEStrategy
......@@ -35,8 +33,8 @@ if __name__ == '__main__':
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform)
train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2)
......@@ -56,4 +54,4 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081)
exp.run(exp_config, 8097)
......@@ -2,9 +2,9 @@ import random
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import nni.retiarii.trainer.pytorch.lightning as pl
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn.functional as F
from nni.retiarii import blackbox_module as bm
from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from torch.utils.data import DataLoader
from torchvision import transforms
......@@ -36,8 +36,8 @@ class Net(nn.Module):
if __name__ == '__main__':
base_model = Net(128)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = bm(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = bm(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2)
......
......@@ -340,7 +340,7 @@
}
]
},
"_training_config": {
"_evaluator": {
"module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
"kwargs": {
"dataset_cls": "MNIST",
......
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii import basic_unit
@blackbox_module
@basic_unit
class ImportTest(nn.Module):
def __init__(self, foo, bar):
super().__init__()
......
......@@ -4,7 +4,7 @@ import logging
import torch
import torch.nn as nn
from nni.retiarii.utils import add_record, del_record, version_larger_equal
from nni.retiarii.utils import version_larger_equal
_logger = logging.getLogger(__name__)
......@@ -13,39 +13,23 @@ def wrap_module(original_class):
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
original_class.bak_init_for_inject = orig_init
if hasattr(original_class, '__del__'):
orig_del = original_class.__del__
original_class.bak_del_for_inject = orig_del
else:
orig_del = None
original_class.bak_del_for_inject = None
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
self._init_parameters = full_args
orig_init(self, *args, **kws) # Call the original __init__
def __del__(self):
del_record(id(self))
if orig_del is not None:
orig_del(self)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
original_class.__del__ = __del__
return original_class
def unwrap_module(wrapped_class):
if hasattr(wrapped_class, 'bak_init_for_inject'):
wrapped_class.__init__ = wrapped_class.bak_init_for_inject
delattr(wrapped_class, 'bak_init_for_inject')
if hasattr(wrapped_class, 'bak_del_for_inject'):
if wrapped_class.bak_del_for_inject is not None:
wrapped_class.__del__ = wrapped_class.bak_del_for_inject
delattr(wrapped_class, 'bak_del_for_inject')
return None
def remove_inject_pytorch_nn():
......
......@@ -38,7 +38,7 @@
]
},
"_training_config": {
"_evaluator": {
"__type__": "_debug_no_trainer"
}
}
......@@ -38,7 +38,7 @@
]
},
"_training_config": {
"_evaluator": {
"module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
"kwargs": {
"dataset_cls": "MNIST",
......
......@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from nni.retiarii.utils import import_
......
......@@ -12,10 +12,9 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
class MnistNet(nn.Module):
def __init__(self):
......@@ -35,8 +34,8 @@ class MnistNet(nn.Module):
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# NOTE: blackbox module cannot be placed within class or function
@blackbox_module
# NOTE: serialize module cannot be placed within class or function
@basic_unit
class Linear(nn.Module):
def __init__(self, d_embed, d_proj):
super().__init__()
......@@ -66,9 +65,6 @@ class TestConvert(unittest.TestCase):
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir)
from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn()
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
......@@ -458,9 +454,12 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),))
def test_torchvision_resnet18(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),))
from .inject_nn import inject_pytorch_nn, remove_inject_pytorch_nn
try:
inject_pytorch_nn()
self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),))
finally:
remove_inject_pytorch_nn()
def test_resnet(self):
def conv1x1(in_planes, out_planes, stride=1):
......@@ -572,8 +571,11 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(resnet18, (torch.randn(1, 3, 224, 224),))
def test_alexnet(self):
from .inject_nn import inject_pytorch_nn
inject_pytorch_nn()
x = torch.ones(1, 3, 224, 224)
model = torchvision.models.AlexNet()
self.checkExportImport(model, (x,))
from .inject_nn import inject_pytorch_nn, remove_inject_pytorch_nn
try:
inject_pytorch_nn()
x = torch.ones(1, 3, 224, 224)
model = torchvision.models.AlexNet()
self.checkExportImport(model, (x,))
finally:
remove_inject_pytorch_nn()
......@@ -8,10 +8,9 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
# following pytorch v1.7.1
......
......@@ -15,10 +15,8 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
# following pytorch v1.7.1
......
......@@ -14,10 +14,9 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records
class TestPytorch(unittest.TestCase):
......
......@@ -17,7 +17,6 @@ from nni.retiarii import Model, Node
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from nni.retiarii.utils import import_
......@@ -74,7 +73,7 @@ class DedupInputTest(unittest.TestCase):
# sys.path.insert(0, 'generated')
# multi_model = import_('debug_dedup_input.logical_0')
# trainer = PyTorchMultiModelTrainer(
# multi_model(), phy_models[0][0].training_config.kwargs
# multi_model(), phy_models[0][0].evaluator.kwargs
# )
# trainer.fit()
......
......@@ -9,7 +9,7 @@ import nni
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor, register_advisor
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer
from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer
from nni.retiarii.utils import import_
......
......@@ -23,7 +23,7 @@ def _test_file(json_path):
# add default values to JSON, so we can compare with `==`
for graph_name, graph in orig_ir.items():
if graph_name == '_training_config':
if graph_name == '_evaluator':
continue
if 'inputs' not in graph:
graph['inputs'] = None
......
......@@ -4,7 +4,7 @@ import unittest
import nni.retiarii.nn.pytorch as nn
import torch
import torch.nn.functional as F
from nni.retiarii import Sampler, blackbox_module
from nni.retiarii import Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
......@@ -29,7 +29,7 @@ class RandomSampler(Sampler):
return random.choice(candidates)
@blackbox_module
@basic_unit
class MutableConv(nn.Module):
def __init__(self):
super().__init__()
......
......@@ -2,13 +2,13 @@ import json
import pytest
import nni
import nni.retiarii.trainer.pytorch.lightning as pl
import nni.retiarii.evaluator.pytorch.lightning as pl
import pytorch_lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii import blackbox_module as bm
from nni.retiarii.trainer import FunctionalTrainer
from nni.retiarii import serialize_cls, serialize
from nni.retiarii.evaluator import FunctionalEvaluator
from sklearn.datasets import load_diabetes
from torch.utils.data import Dataset
from torchvision import transforms
......@@ -49,7 +49,7 @@ class FCNet(nn.Module):
return output.view(-1)
@bm
@serialize_cls
class DiabetesDataset(Dataset):
def __init__(self, train=True):
data = load_diabetes()
......@@ -91,8 +91,8 @@ def _reset():
def test_mnist():
_reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = bm(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = bm(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2, limit_train_batches=0.25, # for faster training
......@@ -121,7 +121,7 @@ def test_diabetes():
@pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.')
def test_functional():
FunctionalTrainer(_foo)._execute(MNISTModel)
FunctionalEvaluator(_foo)._execute(MNISTModel)
if __name__ == '__main__':
......
......@@ -4,7 +4,7 @@ import re
import sys
import torch
from nni.retiarii import json_dumps, json_loads, blackbox
from nni.retiarii import json_dumps, json_loads, serialize
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
......@@ -23,30 +23,30 @@ class Foo:
return self.aa == other.aa and self.bb == other.bb
def test_blackbox():
module = blackbox(Foo, 3)
def test_serialize():
module = serialize(Foo, 3)
assert json_loads(json_dumps(module)) == module
module = blackbox(Foo, b=2, a=1)
module = serialize(Foo, b=2, a=1)
assert json_loads(json_dumps(module)) == module
module = blackbox(Foo, Foo(1), 5)
module = serialize(Foo, Foo(1), 5)
dumped_module = json_dumps(module)
assert len(dumped_module) > 200 # should not be too longer if the serialization is correct
module = blackbox(Foo, blackbox(Foo, 1), 5)
module = serialize(Foo, serialize(Foo, 1), 5)
dumped_module = json_dumps(module)
assert len(dumped_module) < 200 # should not be too longer if the serialization is correct
assert json_loads(dumped_module) == module
def test_blackbox_module():
def test_basic_unit():
module = ImportTest(3, 0.5)
assert json_loads(json_dumps(module)) == module
def test_dataset():
dataset = blackbox(MNIST, root='data/mnist', train=False, download=True)
dataloader = blackbox(DataLoader, dataset, batch_size=10)
dataset = serialize(MNIST, root='data/mnist', train=False, download=True)
dataloader = serialize(DataLoader, dataset, batch_size=10)
dumped_ans = {
"__type__": "torch.utils.data.dataloader.DataLoader",
......@@ -62,19 +62,19 @@ def test_dataset():
dataloader = json_loads(json_dumps(dumped_ans))
assert isinstance(dataloader, DataLoader)
dataset = blackbox(MNIST, root='data/mnist', train=False, download=True,
transform=blackbox(
dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
transform=serialize(
transforms.Compose,
[blackbox(transforms.ToTensor), blackbox(transforms.Normalize, (0.1307,), (0.3081,))]
[serialize(transforms.ToTensor), serialize(transforms.Normalize, (0.1307,), (0.3081,))]
))
dataloader = blackbox(DataLoader, dataset, batch_size=10)
dataloader = serialize(DataLoader, dataset, batch_size=10)
x, y = next(iter(json_loads(json_dumps(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10])
dataset = blackbox(MNIST, root='data/mnist', train=False, download=True,
dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
dataloader = blackbox(DataLoader, dataset, batch_size=10)
dataloader = serialize(DataLoader, dataset, batch_size=10)
x, y = next(iter(json_loads(json_dumps(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10])
......@@ -87,7 +87,7 @@ def test_type():
if __name__ == '__main__':
test_blackbox()
test_blackbox_module()
test_serialize()
test_basic_unit()
test_dataset()
test_type()
......@@ -12,7 +12,7 @@ from nni.retiarii import Model
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.execution import wait_models
from nni.retiarii.execution.interface import AbstractExecutionEngine, WorkerInfo, MetricData, AbstractGraphListener
from nni.retiarii.graph import DebugTraining, ModelStatus
from nni.retiarii.graph import DebugEvaluator, ModelStatus
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
......@@ -80,7 +80,7 @@ def _get_model_and_mutators():
base_model = Net()
script_module = torch.jit.script(base_model)
base_model_ir = convert_to_graph(script_module, base_model)
base_model_ir.training_config = DebugTraining()
base_model_ir.evaluator = DebugEvaluator()
mutators = process_inline_mutation(base_model_ir)
return base_model_ir, mutators
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment