Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c02f4922
"vscode:/vscode.git/clone" did not exist on "2c0cce90d0d2bf9e361e514134fa5689f9f46db4"
Unverified
Commit
c02f4922
authored
Jan 12, 2022
by
Yuge Zhang
Committed by
GitHub
Jan 12, 2022
Browse files
Fix model type in lightning (#4451)
parent
2772751d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
8 deletions
+24
-8
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+5
-5
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+19
-3
No files found.
nni/retiarii/evaluator/pytorch/lightning.py
View file @
c02f4922
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
os
import
os
import
warnings
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
NoReturn
,
Union
,
Optional
,
List
,
Type
from
typing
import
Dict
,
Union
,
Optional
,
List
,
Type
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -33,11 +33,11 @@ class LightningModule(pl.LightningModule):
...
@@ -33,11 +33,11 @@ class LightningModule(pl.LightningModule):
Lightning modules used in NNI should inherit this class.
Lightning modules used in NNI should inherit this class.
"""
"""
def
set_model
(
self
,
model
:
Union
[
Type
[
nn
.
Module
],
nn
.
Module
])
->
NoReturn
:
def
set_model
(
self
,
model
:
Union
[
Type
[
nn
.
Module
],
nn
.
Module
])
->
None
:
if
isinstance
(
model
,
type
):
if
isinstance
(
model
,
nn
.
Module
):
self
.
model
=
model
()
else
:
self
.
model
=
model
self
.
model
=
model
else
:
self
.
model
=
model
()
Trainer
=
nni
.
trace
(
pl
.
Trainer
)
Trainer
=
nni
.
trace
(
pl
.
Trainer
)
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
c02f4922
...
@@ -8,7 +8,6 @@ import pytorch_lightning
...
@@ -8,7 +8,6 @@ import pytorch_lightning
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
nni.retiarii
import
serialize_cls
,
serialize
from
nni.retiarii.evaluator
import
FunctionalEvaluator
from
nni.retiarii.evaluator
import
FunctionalEvaluator
from
sklearn.datasets
import
load_diabetes
from
sklearn.datasets
import
load_diabetes
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
...
@@ -92,8 +91,8 @@ def _reset():
...
@@ -92,8 +91,8 @@ def _reset():
def
test_mnist
():
def
test_mnist
():
_reset
()
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
serializ
e
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
nni
.
trac
e
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serializ
e
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
nni
.
trac
e
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
...
@@ -125,7 +124,24 @@ def test_functional():
...
@@ -125,7 +124,24 @@ def test_functional():
FunctionalEvaluator
(
_foo
).
_execute
(
MNISTModel
)
FunctionalEvaluator
(
_foo
).
_execute
(
MNISTModel
)
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_fit_api
():
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
nni
.
trace
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
nni
.
trace
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.1
,
# for faster training
progress_bar_refresh_rate
=
progress_bar_refresh_rate
)
lightning
.
fit
(
lambda
:
MNISTModel
())
lightning
.
fit
(
MNISTModel
)
lightning
.
fit
(
MNISTModel
())
_reset
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_mnist
()
test_mnist
()
test_diabetes
()
test_diabetes
()
test_functional
()
test_functional
()
test_fit_api
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment