Unverified Commit c02f4922 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix model type in lightning (#4451)

parent 2772751d
......@@ -4,7 +4,7 @@
import os
import warnings
from pathlib import Path
from typing import Dict, NoReturn, Union, Optional, List, Type
from typing import Dict, Union, Optional, List, Type
import pytorch_lightning as pl
import torch.nn as nn
......@@ -33,11 +33,11 @@ class LightningModule(pl.LightningModule):
Lightning modules used in NNI should inherit this class.
"""
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> NoReturn:
if isinstance(model, type):
self.model = model()
else:
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> None:
if isinstance(model, nn.Module):
self.model = model
else:
self.model = model()
Trainer = nni.trace(pl.Trainer)
......
......@@ -8,7 +8,6 @@ import pytorch_lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii import serialize_cls, serialize
from nni.retiarii.evaluator import FunctionalEvaluator
from sklearn.datasets import load_diabetes
from torch.utils.data import Dataset
......@@ -92,8 +91,8 @@ def _reset():
def test_mnist():
_reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=2, limit_train_batches=0.25, # for faster training
......@@ -125,7 +124,24 @@ def test_functional():
FunctionalEvaluator(_foo)._execute(MNISTModel)
@pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs.')
def test_fit_api():
_reset()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.1, # for faster training
progress_bar_refresh_rate=progress_bar_refresh_rate)
lightning.fit(lambda: MNISTModel())
lightning.fit(MNISTModel)
lightning.fit(MNISTModel())
_reset()
if __name__ == '__main__':
test_mnist()
test_diabetes()
test_functional()
test_fit_api()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment