Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c02f4922
Unverified
Commit
c02f4922
authored
Jan 12, 2022
by
Yuge Zhang
Committed by
GitHub
Jan 12, 2022
Browse files
Fix model type in lightning (#4451)
parent
2772751d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
8 deletions
+24
-8
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+5
-5
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+19
-3
No files found.
nni/retiarii/evaluator/pytorch/lightning.py
View file @
c02f4922
...
...
@@ -4,7 +4,7 @@
import
os
import
warnings
from
pathlib
import
Path
from
typing
import
Dict
,
NoReturn
,
Union
,
Optional
,
List
,
Type
from
typing
import
Dict
,
Union
,
Optional
,
List
,
Type
import
pytorch_lightning
as
pl
import
torch.nn
as
nn
...
...
@@ -33,11 +33,11 @@ class LightningModule(pl.LightningModule):
Lightning modules used in NNI should inherit this class.
"""
def
set_model
(
self
,
model
:
Union
[
Type
[
nn
.
Module
],
nn
.
Module
])
->
NoReturn
:
if
isinstance
(
model
,
type
):
self
.
model
=
model
()
else
:
def
set_model
(
self
,
model
:
Union
[
Type
[
nn
.
Module
],
nn
.
Module
])
->
None
:
if
isinstance
(
model
,
nn
.
Module
):
self
.
model
=
model
else
:
self
.
model
=
model
()
Trainer
=
nni
.
trace
(
pl
.
Trainer
)
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
c02f4922
...
...
@@ -8,7 +8,6 @@ import pytorch_lightning
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.retiarii
import
serialize_cls
,
serialize
from
nni.retiarii.evaluator
import
FunctionalEvaluator
from
sklearn.datasets
import
load_diabetes
from
torch.utils.data
import
Dataset
...
...
@@ -92,8 +91,8 @@ def _reset():
def
test_mnist
():
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
serializ
e
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serializ
e
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
train_dataset
=
nni
.
trac
e
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
nni
.
trac
e
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
...
...
@@ -125,7 +124,24 @@ def test_functional():
FunctionalEvaluator
(
_foo
).
_execute
(
MNISTModel
)
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_fit_api
():
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
nni
.
trace
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
nni
.
trace
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.1
,
# for faster training
progress_bar_refresh_rate
=
progress_bar_refresh_rate
)
lightning
.
fit
(
lambda
:
MNISTModel
())
lightning
.
fit
(
MNISTModel
)
lightning
.
fit
(
MNISTModel
())
_reset
()
if
__name__
==
'__main__'
:
test_mnist
()
test_diabetes
()
test_functional
()
test_fit_api
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment