OpenDAS / nni · Commits · dc58203d

Adopt torchmetrics (#4290)

Unverified commit dc58203d, authored Nov 03, 2021 by Yuge Zhang, committed via GitHub on Nov 03, 2021.
Parent: 8fc555ad

Showing 5 changed files with 23 additions and 19 deletions (+23 −19).
dependencies/recommended.txt                        +2 −1
dependencies/recommended_legacy.txt                 +1 −0
nni/retiarii/evaluator/pytorch/cgo/accelerator.py   +12 −11
nni/retiarii/evaluator/pytorch/cgo/evaluator.py     +4 −4
nni/retiarii/evaluator/pytorch/lightning.py         +4 −3
dependencies/recommended.txt

@@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
 torch == 1.9.0 ; sys_platform == "darwin"
 torchvision == 0.10.0+cpu ; sys_platform != "darwin"
 torchvision == 0.10.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.4.2
+pytorch-lightning >= 1.5
+torchmetrics
 onnx
 peewee
 graphviz
dependencies/recommended_legacy.txt

@@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
 # It will install pytorch-lightning 0.8.x and unit tests won't work.
 # Latest version has conflict with tensorboard and tensorflow 1.x.
 pytorch-lightning
+torchmetrics
 keras == 2.1.6
 onnx
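The two dependency changes above are what drive the rest of the commit: pytorch-lightning is raised to at least 1.5 and torchmetrics becomes an explicit requirement. As a quick post-install check (not part of the repository; the version comparison below is an illustrative assumption about how one might verify it), something like the following confirms both constraints are met:

    # Hypothetical sanity check, not shipped with NNI.
    import pytorch_lightning as pl
    import torchmetrics  # now declared explicitly in dependencies/recommended.txt

    major, minor = (int(part) for part in pl.__version__.split('.')[:2])
    assert (major, minor) >= (1, 5), f"pytorch-lightning {pl.__version__} is below the new 1.5 floor"
    print(f"pytorch-lightning {pl.__version__}, torchmetrics {torchmetrics.__version__}")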
nni/retiarii/evaluator/pytorch/cgo/accelerator.py

-from typing import Any, Union, Optional, List
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+from typing import Any, List, Optional, Union
 import torch
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
 from pytorch_lightning.trainer import Trainer
-from pytorch_lightning.plugins import Plugin
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from ....serializer import serialize_cls

@@ -69,9 +70,8 @@ class BypassPlugin(TrainingTypePlugin):
         # bypass device placement from pytorch lightning
         pass

-    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
-        self.model_to_device()
-        return self.model
+    def setup(self) -> None:
+        pass

     @property
     def is_global_zero(self) -> bool:

@@ -100,8 +100,9 @@ def get_accelerator_connector(
         deterministic: bool = False,
         precision: int = 32,
         amp_backend: str = 'native',
-        amp_level: str = 'O2',
-        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+        amp_level: Optional[str] = None,
+        plugins: Optional[Union[List[Union[TrainingTypePlugin, ClusterEnvironment, str]], TrainingTypePlugin, ClusterEnvironment, str]] = None,
         **other_trainier_kwargs) -> AcceleratorConnector:
     gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
     return AcceleratorConnector(
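The accelerator changes are not about metrics themselves; they track API differences between pytorch-lightning 1.4 and 1.5, which the dependency bump pulls in: TrainingTypePlugin.setup() no longer receives (or returns) the model, the Plugin union type is dropped from the get_accelerator_connector signature in favour of TrainingTypePlugin, and amp_level now defaults to None. A minimal sketch of the new-style override adopted above (the class name is illustrative; the import path is the 1.5-era one used in the diff and may not exist in newer pytorch-lightning releases):

    # Minimal sketch, assuming pytorch-lightning ~1.5 so that this import path exists.
    from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin

    class NoDevicePlacementPlugin(TrainingTypePlugin):
        """Illustrative only: mirrors the setup() signature the diff adopts."""

        def setup(self) -> None:
            # PL 1.5 invokes setup() without a model argument; device placement
            # is deliberately a no-op here, as in BypassPlugin above.
            pass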
nni/retiarii/evaluator/pytorch/cgo/evaluator.py

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union

 import torch.nn as nn
 import torch.optim as optim
-import pytorch_lightning as pl
+import torchmetrics
 from torch.utils.data import DataLoader

 import nni

@@ -19,7 +19,7 @@ from ....serializer import serialize_cls

 @serialize_cls
 class _MultiModelSupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  n_models: int = 0,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,

@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
         Class for optimizer (not an instance). default: ``Adam``
     """
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam):

@@ -180,7 +180,7 @@ class _RegressionModule(MultiModelSupervisedLearningModule):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam):
-        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
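The substantive change in this file is purely the metrics namespace: the Dict[str, pl.metrics.Metric] annotations become Dict[str, torchmetrics.Metric], and the built-in regression metric comes from torchmetrics. A hedged usage sketch of what such a metrics dict looks like (the particular metrics and criterion are illustrative, not taken from the commit); note that metric classes, not instances, are passed, just as the _RegressionModule hunk does with {'mse': torchmetrics.MeanSquaredError}:

    # Illustrative only: a metrics dict in the shape these evaluator modules accept.
    import torch.nn as nn
    import torchmetrics

    criterion = nn.CrossEntropyLoss          # illustrative choice of loss
    metrics = {
        'acc': torchmetrics.Accuracy,        # metric classes, mirroring the hunk above
        'mse': torchmetrics.MeanSquaredError,
    }
    # e.g. _MultiModelSupervisedLearningModule(criterion, metrics, n_models=2)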
nni/retiarii/evaluator/pytorch/lightning.py

@@ -9,6 +9,7 @@ from typing import Dict, NoReturn, Union, Optional, List, Type

 import pytorch_lightning as pl
 import torch.nn as nn
 import torch.optim as optim
+import torchmetrics
 from torch.utils.data import DataLoader

 import nni

@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):

 ### The following are some commonly used Lightning modules ###

 class _SupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam,

@@ -213,7 +214,7 @@ class _SupervisedLearningModule(LightningModule):
         return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}

-class _AccuracyWithLogits(pl.metrics.Accuracy):
+class _AccuracyWithLogits(torchmetrics.Accuracy):

     def update(self, pred, target):
         return super().update(nn.functional.softmax(pred), target)

@@ -278,7 +279,7 @@ class _RegressionModule(_SupervisedLearningModule):
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam,
                  export_onnx: bool = True):
-        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
                          export_onnx=export_onnx)
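_AccuracyWithLogits is the one place where a metric is subclassed rather than just referenced, so the pattern is worth seeing on its own. A hedged, self-contained demonstration follows (assuming a torchmetrics release contemporary with this commit, where Accuracy() needs no arguments; newer releases may require task and num_classes). The dim=-1 in softmax is an explicit choice here; the diff calls nn.functional.softmax(pred) and relies on the default:

    # Stand-alone sketch of the _AccuracyWithLogits pattern; not NNI code.
    import torch
    import torch.nn as nn
    import torchmetrics

    class AccuracyWithLogits(torchmetrics.Accuracy):
        """Accept raw logits, convert to probabilities, then delegate to torchmetrics."""

        def update(self, pred, target):
            return super().update(nn.functional.softmax(pred, dim=-1), target)

    metric = AccuracyWithLogits()            # torchmetrics ~0.6; newer versions may need task=...
    logits = torch.randn(8, 10)              # a batch of 8 samples over 10 classes
    labels = torch.randint(0, 10, (8,))
    metric.update(logits, labels)            # accumulates state
    print(metric.compute())                  # accuracy aggregated over all update() calls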