Unverified Commit c02f4922 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix model type in lightning (#4451)

parent 2772751d
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Dict, NoReturn, Union, Optional, List, Type from typing import Dict, Union, Optional, List, Type
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.nn as nn import torch.nn as nn
...@@ -33,11 +33,11 @@ class LightningModule(pl.LightningModule): ...@@ -33,11 +33,11 @@ class LightningModule(pl.LightningModule):
Lightning modules used in NNI should inherit this class. Lightning modules used in NNI should inherit this class.
""" """
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> NoReturn: def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> None:
if isinstance(model, type): if isinstance(model, nn.Module):
self.model = model()
else:
self.model = model self.model = model
else:
self.model = model()
Trainer = nni.trace(pl.Trainer) Trainer = nni.trace(pl.Trainer)
......
...@@ -8,7 +8,6 @@ import pytorch_lightning ...@@ -8,7 +8,6 @@ import pytorch_lightning
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii import serialize_cls, serialize
from nni.retiarii.evaluator import FunctionalEvaluator from nni.retiarii.evaluator import FunctionalEvaluator
from sklearn.datasets import load_diabetes from sklearn.datasets import load_diabetes
from torch.utils.data import Dataset from torch.utils.data import Dataset
...@@ -92,8 +91,8 @@ def _reset(): ...@@ -92,8 +91,8 @@ def _reset():
def test_mnist(): def test_mnist():
_reset() _reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform) train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform) test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100), lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2, limit_train_batches=0.25, # for faster training max_epochs=2, limit_train_batches=0.25, # for faster training
...@@ -125,7 +124,24 @@ def test_functional(): ...@@ -125,7 +124,24 @@ def test_functional():
FunctionalEvaluator(_foo)._execute(MNISTModel) FunctionalEvaluator(_foo)._execute(MNISTModel)
@pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.')
def test_fit_api():
_reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.1, # for faster training
progress_bar_refresh_rate=progress_bar_refresh_rate)
lightning.fit(lambda: MNISTModel())
lightning.fit(MNISTModel)
lightning.fit(MNISTModel())
_reset()
if __name__ == '__main__': if __name__ == '__main__':
test_mnist() test_mnist()
test_diabetes() test_diabetes()
test_functional() test_functional()
test_fit_api()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment