Unverified Commit 445e7e0b authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Rewrite trainer with PyTorch Lightning (#3359)

parent 137830df
......@@ -8,7 +8,7 @@ from pathlib import Path
from torchvision import transforms
from torchvision.datasets import CIFAR10
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy
from nni.retiarii.trainer.pytorch import DartsTrainer
......
......@@ -27,14 +27,12 @@ class BlockMutator(Mutator):
n_filter = self.choice(related_info['n_filter_options'])
if related_info['in_ch'] is not None:
_logger.info('zql debug X ...')
in_ch = related_info['in_ch']
else:
assert len(node.predecessors) == 1
the_node = node.predecessors[0]
_logger.info('zql debug ...')
_logger.info(the_node.operation.parameters)
_logger.info(the_node.__repr__())
_logger.debug(repr(the_node.operation.parameters))
_logger.debug(the_node.__repr__())
in_ch = the_node.operation.parameters['out_ch']
# update the placeholder to be a new operation
......
......@@ -5,10 +5,14 @@ from pathlib import Path
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii import blackbox_module as bm
from base_mnasnet import MNASNet
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy
from torchvision import transforms
from torchvision.datasets import CIFAR10
from mutator import BlockMutator
if __name__ == '__main__':
......@@ -20,16 +24,27 @@ if __name__ == '__main__':
base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
_DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
# new interface
applied_mutators = []
applied_mutators.append(BlockMutator('mutable_0'))
applied_mutators.append(BlockMutator('mutable_1'))
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2)
applied_mutators = [
BlockMutator('mutable_0'),
BlockMutator('mutable_1')
]
simple_startegy = TPEStrategy()
......
import random
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.trainer.pytorch.lightning as pl
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii import blackbox_module as bm
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
class Net(nn.Module):
......@@ -31,11 +35,12 @@ class Net(nn.Module):
if __name__ == '__main__':
base_model = Net(128)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
dataset_kwargs={"root": "data/mnist", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = bm(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = bm(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2)
simple_startegy = RandomStrategy()
......
import json
import logging
import random
import os
from nni.retiarii import Model, submit_models, wait_models
from nni.retiarii.strategy import BaseStrategy
from nni.retiarii import Sampler
# Module-level logger, named after this module via __name__.
_logger = logging.getLogger(__name__)
class RandomSampler(Sampler):
    """Sampler that resolves every choice by picking one candidate at random."""

    def choice(self, candidates, mutator, model, index):
        """Return one element of ``candidates`` chosen via ``random.choice``.

        ``mutator``, ``model`` and ``index`` are accepted but not consulted.
        """
        picked = random.choice(candidates)
        return picked
class SimpleStrategy(BaseStrategy):
    """Strategy that keeps generating randomly mutated models and running them.

    NOTE(review): the indentation of this class was reconstructed from the
    statement order (``try`` wrapping an endless ``while`` loop); confirm
    against the original file layout.
    """

    def __init__(self):
        self.name = ''

    def run(self, base_model, applied_mutators, trainer):
        """Mutate ``base_model`` in an endless loop and run each result.

        Each iteration starts from ``base_model``, applies every mutator in
        ``applied_mutators`` (all bound to one fresh ``RandomSampler``),
        attaches the trainer described by ``trainer['modulename']`` and
        ``trainer['args']``, submits the model, waits for it and logs its
        metric.

        The loop has no exit condition of its own: it stops only when an
        exception is raised, which is logged with its traceback and then
        swallowed (best-effort behaviour kept from the original).
        """
        try:
            _logger.info('stargety start...')
            while True:
                model = base_model
                _logger.info('apply mutators...')
                # Lazy %-style argument: formatted only if the record is emitted.
                _logger.info('mutators: %s', applied_mutators)
                random_sampler = RandomSampler()
                for mutator in applied_mutators:
                    _logger.info('mutate model...')
                    mutator.bind_sampler(random_sampler)
                    model = mutator.apply(model)
                # get and apply training approach
                _logger.info('apply training approach...')
                model.apply_trainer(trainer['modulename'], trainer['args'])
                # run models
                submit_models(model)
                wait_models(model)
                # A placeholder is required; passing the metric as a bare extra
                # argument makes the logging module fail to format the record.
                _logger.info('Strategy says: %s', model.metric)
        except Exception:
            # Logger.exception logs at ERROR level on *this* logger and appends
            # the active traceback. The previous code logged on the root logger
            # and then logged the literal value None.
            _logger.exception('message')
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
@blackbox_module
class ImportTest(nn.Module):
    """Small module holding one linear layer and one dropout layer.

    ``foo`` is passed as the first argument of ``nn.Linear`` (second argument
    fixed to 3) and ``bar`` is passed to ``nn.Dropout``.
    """

    def __init__(self, foo, bar):
        super().__init__()
        self.foo = nn.Linear(foo, 3)
        self.bar = nn.Dropout(bar)

    def __eq__(self, other):
        """Compare by the linear layer's ``in_features`` and the dropout's ``p``."""
        same_in_features = self.foo.in_features == other.foo.in_features
        if not same_in_features:
            return same_in_features
        return self.bar.p == other.bar.p
......@@ -39,7 +39,6 @@
},
"_training_config": {
"module": "_debug_no_trainer",
"kwargs": {}
"__type__": "_debug_no_trainer"
}
}
import json
import pytest
import nni
import nni.retiarii.trainer.pytorch.lightning as pl
import pytorch_lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii import blackbox_module as bm
from nni.retiarii.trainer import FunctionalTrainer
from sklearn.datasets import load_diabetes
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
# Flip to True to get a progress bar refresh rate of 1 instead of 0.
debug = False
progress_bar_refresh_rate = 1 if debug else 0
class MNISTModel(nn.Module):
    """Two-layer fully connected network: 28*28 inputs -> 128 -> 10 outputs."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten each sample, then apply layer_1, ReLU and layer_2."""
        # Keep the first dimension and collapse everything else into one axis.
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.layer_1(flat))
        return self.layer_2(hidden)
class FCNet(nn.Module):
    """Fully connected network: input_size -> 5 -> output_size with a ReLU between."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.l1 = nn.Linear(input_size, 5)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(5, output_size)

    def forward(self, x):
        """Apply l1, ReLU and l2, then flatten the result to one dimension."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden).view(-1)
@bm
class DiabetesDataset(Dataset):
    """Dataset over ``sklearn.datasets.load_diabetes`` with an 80/20 split.

    ``train=True`` keeps the first 80% of the rows (by position); otherwise
    the remaining rows are kept.
    """

    def __init__(self, train=True):
        raw = load_diabetes()
        features = torch.tensor(raw['data'], dtype=torch.float32)
        targets = torch.tensor(raw['target'], dtype=torch.float32)
        # Split point is computed on the full row count before slicing.
        boundary = int(features.shape[0] * 0.8)
        part = slice(None, boundary) if train else slice(boundary, None)
        self.x = features[part]
        self.y = targets[part]
        self.length = len(self.y)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair stored at ``idx``."""
        return self.x[idx], self.y[idx]

    def __len__(self):
        """Return the number of rows kept for this split."""
        return self.length
def _get_final_result():
    """Return the ``'value'`` field of the last recorded metric as a float.

    The metric is read from ``nni.runtime.platform.test._last_metric`` and
    parsed as JSON.
    """
    last_metric = json.loads(nni.runtime.platform.test._last_metric)
    return float(last_metric['value'])
def _foo(model_cls):
    """Training function used by ``test_functional``.

    It only asserts that the object it receives is ``MNISTModel``.
    """
    assert model_cls == MNISTModel
def _reset():
    """Put the nni trial/test-platform globals back to fixed values."""
    # this is to not affect other tests in sdk
    trial_state = nni.trial
    trial_state._intermediate_seq = 0
    trial_state._params = {'foo': 'bar', 'parameter_id': 0}
    nni.runtime.platform.test._last_metric = None
@pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.')
def test_mnist():
    """Train ``MNISTModel`` through ``pl.Classification`` and check the final metric."""
    _reset()
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    # Build the train split first, then the test split, with identical settings.
    train_dataset, test_dataset = (
        bm(MNIST)(root='data/mnist', train=is_train, download=True, transform=transform)
        for is_train in (True, False)
    )
    train_loader = pl.DataLoader(train_dataset, batch_size=100)
    val_loader = pl.DataLoader(test_dataset, batch_size=100)
    lightning = pl.Classification(
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
        max_epochs=2,
        limit_train_batches=0.25,  # for faster training
        progress_bar_refresh_rate=progress_bar_refresh_rate,
    )
    lightning._execute(MNISTModel)
    assert _get_final_result() > 0.7
    _reset()
@pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.')
def test_diabetes():
    """Train ``FCNet`` through ``pl.Regression`` and check the final metric."""
    _reset()
    # These two assignments write the same values that _reset() writes.
    nni.trial._params = {'foo': 'bar', 'parameter_id': 0}
    nni.runtime.platform.test._last_metric = None
    train_dataset = DiabetesDataset(train=True)
    test_dataset = DiabetesDataset(train=False)
    train_loader = pl.DataLoader(train_dataset, batch_size=20)
    val_loader = pl.DataLoader(test_dataset, batch_size=20)
    lightning = pl.Regression(
        optimizer=torch.optim.SGD,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
        max_epochs=100,
        progress_bar_refresh_rate=progress_bar_refresh_rate,
    )
    # The network's input width is the number of feature columns in the data.
    lightning._execute(FCNet(train_dataset.x.shape[1], 1))
    assert _get_final_result() < 2e4
    _reset()
@pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.')
def test_functional():
    """Run ``FunctionalTrainer`` wrapping ``_foo`` on ``MNISTModel``."""
    trainer = FunctionalTrainer(_foo)
    trainer._execute(MNISTModel)
if __name__ == '__main__':
    # Run the tests directly, in definition order, when executed as a script.
    for _case in (test_mnist, test_diabetes, test_functional):
        _case()
import json
from pathlib import Path
import re
import sys
import torch
from nni.retiarii import json_dumps, json_loads, blackbox
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
sys.path.insert(0, Path(__file__).parent.as_posix())
from imported.model import ImportTest
class Foo:
    """Plain helper object for the serializer tests.

    ``aa`` stores ``a`` unchanged; ``bb`` is a 1000-element list whose
    entries are each ``b + 1``.
    """

    def __init__(self, a, b=1):
        self.aa = a
        filler = []
        for _ in range(1000):
            filler.append(b + 1)
        self.bb = filler

    def __eq__(self, other):
        """Compare ``aa`` first and, if that matches, ``bb``."""
        same_aa = self.aa == other.aa
        if not same_aa:
            return same_aa
        return self.bb == other.bb
def test_blackbox():
    """Check JSON round trips and dump sizes for ``blackbox``-wrapped ``Foo``."""

    def round_trip(obj):
        return json_loads(json_dumps(obj))

    # Positional argument only.
    module = blackbox(Foo, 3)
    assert round_trip(module) == module

    # Keyword arguments given out of signature order.
    module = blackbox(Foo, b=2, a=1)
    assert round_trip(module) == module

    # A plain (non-blackbox) Foo as argument is expected to give a long dump.
    # NOTE(review): presumably because the whole object, including its
    # 1000-element list, has to be embedded; confirm against the serializer.
    module = blackbox(Foo, Foo(1), 5)
    dumped_module = json_dumps(module)
    assert len(dumped_module) > 200

    # A blackbox-wrapped Foo as argument should keep the dump short
    # if the serialization is correct.
    module = blackbox(Foo, blackbox(Foo, 1), 5)
    dumped_module = json_dumps(module)
    assert len(dumped_module) < 200
    assert json_loads(dumped_module) == module
def test_blackbox_module():
    """Check that an ``ImportTest`` instance survives a JSON round trip."""
    module = ImportTest(3, 0.5)
    restored = json_loads(json_dumps(module))
    assert restored == module
def test_dataset():
    """Check serialization of blackbox-wrapped MNIST datasets and data loaders."""

    def assert_first_batch(ds):
        # Wrap in a DataLoader, round-trip it through JSON, and check the
        # sizes of the first batch it yields.
        loader = blackbox(DataLoader, ds, batch_size=10)
        images, labels = next(iter(json_loads(json_dumps(loader))))
        assert images.size() == torch.Size([10, 1, 28, 28])
        assert labels.size() == torch.Size([10])

    # 1. The dump of a wrapped DataLoader matches the hand-written structure,
    #    and that structure loads back as a real DataLoader.
    dataset = blackbox(MNIST, root='data/mnist', train=False, download=True)
    dataloader = blackbox(DataLoader, dataset, batch_size=10)
    dumped_ans = {
        "__type__": "torch.utils.data.dataloader.DataLoader",
        "arguments": {
            "batch_size": 10,
            "dataset": {
                "__type__": "torchvision.datasets.mnist.MNIST",
                "arguments": {"root": "data/mnist", "train": False, "download": True}
            }
        }
    }
    assert json_dumps(dataloader) == json_dumps(dumped_ans)
    restored = json_loads(json_dumps(dumped_ans))
    assert isinstance(restored, DataLoader)

    # 2. Transform built entirely from blackbox-wrapped pieces.
    dataset = blackbox(MNIST, root='data/mnist', train=False, download=True,
                       transform=blackbox(
                           transforms.Compose,
                           [blackbox(transforms.ToTensor), blackbox(transforms.Normalize, (0.1307,), (0.3081,))]
                       ))
    assert_first_batch(dataset)

    # 3. Transform built from ordinary torchvision objects.
    dataset = blackbox(MNIST, root='data/mnist', train=False, download=True,
                       transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
    assert_first_batch(dataset)
def test_type():
    """Check that classes are dumped and loaded by their qualified type name."""
    adam_json = '{"__typename__": "torch.optim.adam.Adam"}'
    assert json_dumps(torch.optim.Adam) == adam_json
    assert json_loads(adam_json) == torch.optim.Adam
    # The module prefix of Foo depends on how this file is imported, so only
    # the tail of the name is pinned.
    foo_json = json_dumps(Foo)
    assert re.match(r'{"__typename__": "(.*)test_serializer.Foo"}', foo_json)
if __name__ == '__main__':
    # Run the tests directly, in definition order, when executed as a script.
    for _case in (test_blackbox, test_blackbox_module, test_dataset, test_type):
        _case()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment