Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
051ed9e6
Unverified
Commit
051ed9e6
authored
May 10, 2021
by
Yuge Zhang
Committed by
GitHub
May 10, 2021
Browse files
Unpin `pytorch_lightning<1.2` (#3598)
parent
b7f374ce
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
10 additions
and
4 deletions
+10
-4
dependencies/recommended.txt
dependencies/recommended.txt
+1
-1
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+6
-1
pipelines/full-test-linux.yml
pipelines/full-test-linux.yml
+1
-1
pipelines/full-test-windows.yml
pipelines/full-test-windows.yml
+1
-1
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+1
-0
No files found.
dependencies/recommended.txt
View file @
051ed9e6
...
...
@@ -6,7 +6,7 @@ torch == 1.6.0+cpu ; sys_platform != "darwin"
torch == 1.6.0 ; sys_platform == "darwin"
torchvision == 0.7.0+cpu ; sys_platform != "darwin"
torchvision == 0.7.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.1.1
, < 1.2
pytorch-lightning >= 1.1.1
onnx
peewee
graphviz
nni/retiarii/evaluator/pytorch/lightning.py
View file @
051ed9e6
...
...
@@ -165,13 +165,18 @@ class _SupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
class
_AccuracyWithLogits
(
pl
.
metrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
class
_ClassificationModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
super
().
__init__
(
criterion
,
{
'acc'
:
pl
.
metrics
.
Accuracy
},
super
().
__init__
(
criterion
,
{
'acc'
:
_AccuracyWithLogits
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
...
...
pipelines/full-test-linux.yml
View file @
051ed9e6
...
...
@@ -31,7 +31,7 @@ jobs:
python3 -m pip install scikit-learn==0.24.1
python3 -m pip install torchvision==0.7.0
python3 -m pip install torch==1.6.0
python3 -m pip install 'pytorch-lightning>=1.1.1
,<1.2
'
python3 -m pip install 'pytorch-lightning>=1.1.1'
python3 -m pip install keras==2.1.6
python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python3 -m pip install thop
...
...
pipelines/full-test-windows.yml
View file @
051ed9e6
...
...
@@ -28,7 +28,7 @@ jobs:
python -m pip install scikit-learn==0.24.1
python -m pip install keras==2.1.6
python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install 'pytorch-lightning>=1.1.1
,<1.2
'
python -m pip install 'pytorch-lightning>=1.1.1'
python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
displayName
:
Install extra dependencies
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
051ed9e6
...
...
@@ -3,6 +3,7 @@ import pytest
import
nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
nni.runtime.platform.test
import
pytorch_lightning
import
torch
import
torch.nn
as
nn
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment