Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
051ed9e6
Unverified
Commit
051ed9e6
authored
May 10, 2021
by
Yuge Zhang
Committed by
GitHub
May 10, 2021
Browse files
Unpin `pytorch_lightning<1.2` (#3598)
parent
b7f374ce
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
10 additions
and
4 deletions
+10
-4
dependencies/recommended.txt
dependencies/recommended.txt
+1
-1
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+6
-1
pipelines/full-test-linux.yml
pipelines/full-test-linux.yml
+1
-1
pipelines/full-test-windows.yml
pipelines/full-test-windows.yml
+1
-1
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+1
-0
No files found.
dependencies/recommended.txt
View file @
051ed9e6
...
...
@@ -6,7 +6,7 @@ torch == 1.6.0+cpu ; sys_platform != "darwin"
torch == 1.6.0 ; sys_platform == "darwin"
torchvision == 0.7.0+cpu ; sys_platform != "darwin"
torchvision == 0.7.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.1.1
, < 1.2
pytorch-lightning >= 1.1.1
onnx
peewee
graphviz
nni/retiarii/evaluator/pytorch/lightning.py
View file @
051ed9e6
...
...
@@ -165,13 +165,18 @@ class _SupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
class
_AccuracyWithLogits
(
pl
.
metrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
class
_ClassificationModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
super
().
__init__
(
criterion
,
{
'acc'
:
pl
.
metrics
.
Accuracy
},
super
().
__init__
(
criterion
,
{
'acc'
:
_AccuracyWithLogits
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
...
...
pipelines/full-test-linux.yml
View file @
051ed9e6
...
...
@@ -31,7 +31,7 @@ jobs:
python3 -m pip install scikit-learn==0.24.1
python3 -m pip install torchvision==0.7.0
python3 -m pip install torch==1.6.0
python3 -m pip install 'pytorch-lightning>=1.1.1
,<1.2
'
python3 -m pip install 'pytorch-lightning>=1.1.1'
python3 -m pip install keras==2.1.6
python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python3 -m pip install thop
...
...
pipelines/full-test-windows.yml
View file @
051ed9e6
...
...
@@ -28,7 +28,7 @@ jobs:
python -m pip install scikit-learn==0.24.1
python -m pip install keras==2.1.6
python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install 'pytorch-lightning>=1.1.1
,<1.2
'
python -m pip install 'pytorch-lightning>=1.1.1'
python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
displayName
:
Install extra dependencies
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
051ed9e6
...
...
@@ -3,6 +3,7 @@ import pytest
import
nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
nni.runtime.platform.test
import
pytorch_lightning
import
torch
import
torch.nn
as
nn
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment