Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
dc58203d
Unverified
Commit
dc58203d
authored
Nov 03, 2021
by
Yuge Zhang
Committed by
GitHub
Nov 03, 2021
Browse files
Adopt torchmetrics (#4290)
parent
8fc555ad
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
19 deletions
+23
-19
dependencies/recommended.txt
dependencies/recommended.txt
+2
-1
dependencies/recommended_legacy.txt
dependencies/recommended_legacy.txt
+1
-0
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
+12
-11
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
+4
-4
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+4
-3
No files found.
dependencies/recommended.txt
View file @
dc58203d
...
...
@@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.4.2
pytorch-lightning >= 1.5
torchmetrics
onnx
peewee
graphviz
...
...
dependencies/recommended_legacy.txt
View file @
dc58203d
...
...
@@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
# It will install pytorch-lightning 0.8.x and unit tests won't work.
# Latest version has conflict with tensorboard and tensorflow 1.x.
pytorch-lightning
torchmetrics
keras == 2.1.6
onnx
...
...
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
View file @
dc58203d
from
typing
import
Any
,
Union
,
Optional
,
List
import
torch
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch
from
pytorch_lightning.accelerators.accelerator
import
Accelerator
from
pytorch_lightning.plugins.environments
import
ClusterEnvironment
from
pytorch_lightning.plugins.training_type.training_type_plugin
import
TrainingTypePlugin
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
pytorch_lightning.trainer
import
Trainer
from
pytorch_lightning.plugins
import
Plugin
from
pytorch_lightning.plugins.environments
import
ClusterEnvironment
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
....serializer
import
serialize_cls
...
...
@@ -69,9 +70,8 @@ class BypassPlugin(TrainingTypePlugin):
# bypass device placement from pytorch lightning
pass
def
setup
(
self
,
model
:
torch
.
nn
.
Module
)
->
torch
.
nn
.
Module
:
self
.
model_to_device
()
return
self
.
model
def
setup
(
self
)
->
None
:
pass
@
property
def
is_global_zero
(
self
)
->
bool
:
...
...
@@ -100,8 +100,9 @@ def get_accelerator_connector(
deterministic
:
bool
=
False
,
precision
:
int
=
32
,
amp_backend
:
str
=
'native'
,
amp_level
:
str
=
'O2'
,
plugins
:
Optional
[
Union
[
List
[
Union
[
Plugin
,
ClusterEnvironment
,
str
]],
Plugin
,
ClusterEnvironment
,
str
]]
=
None
,
amp_level
:
Optional
[
str
]
=
None
,
plugins
:
Optional
[
Union
[
List
[
Union
[
TrainingTypePlugin
,
ClusterEnvironment
,
str
]],
TrainingTypePlugin
,
ClusterEnvironment
,
str
]]
=
None
,
**
other_trainier_kwargs
)
->
AcceleratorConnector
:
gpu_ids
=
Trainer
().
_parse_devices
(
gpus
,
auto_select_gpus
,
tpu_cores
)
return
AcceleratorConnector
(
...
...
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
View file @
dc58203d
...
...
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union
import
torch.nn
as
nn
import
torch.optim
as
optim
import
py
torch
_lightning
as
pl
import
torch
metrics
from
torch.utils.data
import
DataLoader
import
nni
...
...
@@ -19,7 +19,7 @@ from ....serializer import serialize_cls
@
serialize_cls
class
_MultiModelSupervisedLearningModule
(
LightningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
pl
.
metrics
.
Metric
],
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torch
metrics
.
Metric
],
n_models
:
int
=
0
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
...
...
@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
Class for optimizer (not an instance). default: ``Adam``
"""
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
pl
.
metrics
.
Metric
],
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torch
metrics
.
Metric
],
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
...
...
@@ -180,7 +180,7 @@ class _RegressionModule(MultiModelSupervisedLearningModule):
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
super
().
__init__
(
criterion
,
{
'mse'
:
pl
.
metrics
.
MeanSquaredError
},
super
().
__init__
(
criterion
,
{
'mse'
:
torch
metrics
.
MeanSquaredError
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
...
...
nni/retiarii/evaluator/pytorch/lightning.py
View file @
dc58203d
...
...
@@ -9,6 +9,7 @@ from typing import Dict, NoReturn, Union, Optional, List, Type
import
pytorch_lightning
as
pl
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torchmetrics
from
torch.utils.data
import
DataLoader
import
nni
...
...
@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
class
_SupervisedLearningModule
(
LightningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
pl
.
metrics
.
Metric
],
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torch
metrics
.
Metric
],
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
...
...
@@ -213,7 +214,7 @@ class _SupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
class
_AccuracyWithLogits
(
pl
.
metrics
.
Accuracy
):
class
_AccuracyWithLogits
(
torch
metrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
...
...
@@ -278,7 +279,7 @@ class _RegressionModule(_SupervisedLearningModule):
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
export_onnx
:
bool
=
True
):
super
().
__init__
(
criterion
,
{
'mse'
:
pl
.
metrics
.
MeanSquaredError
},
super
().
__init__
(
criterion
,
{
'mse'
:
torch
metrics
.
MeanSquaredError
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment