Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
dc58203d
Unverified
Commit
dc58203d
authored
Nov 03, 2021
by
Yuge Zhang
Committed by
GitHub
Nov 03, 2021
Browse files
Adopt torchmetrics (#4290)
parent
8fc555ad
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
19 deletions
+23
-19
dependencies/recommended.txt
dependencies/recommended.txt
+2
-1
dependencies/recommended_legacy.txt
dependencies/recommended_legacy.txt
+1
-0
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
+12
-11
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
+4
-4
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+4
-3
No files found.
dependencies/recommended.txt
View file @
dc58203d
...
@@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
...
@@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.4.2
pytorch-lightning >= 1.5
torchmetrics
onnx
onnx
peewee
peewee
graphviz
graphviz
...
...
dependencies/recommended_legacy.txt
View file @
dc58203d
...
@@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
...
@@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
# It will install pytorch-lightning 0.8.x and unit tests won't work.
# It will install pytorch-lightning 0.8.x and unit tests won't work.
# Latest version has conflict with tensorboard and tensorflow 1.x.
# Latest version has conflict with tensorboard and tensorflow 1.x.
pytorch-lightning
pytorch-lightning
torchmetrics
keras == 2.1.6
keras == 2.1.6
onnx
onnx
...
...
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
View file @
dc58203d
from
typing
import
Any
,
Union
,
Optional
,
List
# Copyright (c) Microsoft Corporation.
import
torch
# Licensed under the MIT license.
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch
from
pytorch_lightning.accelerators.accelerator
import
Accelerator
from
pytorch_lightning.accelerators.accelerator
import
Accelerator
from
pytorch_lightning.plugins.environments
import
ClusterEnvironment
from
pytorch_lightning.plugins.training_type.training_type_plugin
import
TrainingTypePlugin
from
pytorch_lightning.plugins.training_type.training_type_plugin
import
TrainingTypePlugin
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
pytorch_lightning.trainer
import
Trainer
from
pytorch_lightning.trainer
import
Trainer
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
pytorch_lightning.plugins
import
Plugin
from
pytorch_lightning.plugins.environments
import
ClusterEnvironment
from
....serializer
import
serialize_cls
from
....serializer
import
serialize_cls
...
@@ -69,9 +70,8 @@ class BypassPlugin(TrainingTypePlugin):
...
@@ -69,9 +70,8 @@ class BypassPlugin(TrainingTypePlugin):
# bypass device placement from pytorch lightning
# bypass device placement from pytorch lightning
pass
pass
def
setup
(
self
,
model
:
torch
.
nn
.
Module
)
->
torch
.
nn
.
Module
:
def
setup
(
self
)
->
None
:
self
.
model_to_device
()
pass
return
self
.
model
@
property
@
property
def
is_global_zero
(
self
)
->
bool
:
def
is_global_zero
(
self
)
->
bool
:
...
@@ -100,8 +100,9 @@ def get_accelerator_connector(
...
@@ -100,8 +100,9 @@ def get_accelerator_connector(
deterministic
:
bool
=
False
,
deterministic
:
bool
=
False
,
precision
:
int
=
32
,
precision
:
int
=
32
,
amp_backend
:
str
=
'native'
,
amp_backend
:
str
=
'native'
,
amp_level
:
str
=
'O2'
,
amp_level
:
Optional
[
str
]
=
None
,
plugins
:
Optional
[
Union
[
List
[
Union
[
Plugin
,
ClusterEnvironment
,
str
]],
Plugin
,
ClusterEnvironment
,
str
]]
=
None
,
plugins
:
Optional
[
Union
[
List
[
Union
[
TrainingTypePlugin
,
ClusterEnvironment
,
str
]],
TrainingTypePlugin
,
ClusterEnvironment
,
str
]]
=
None
,
**
other_trainier_kwargs
)
->
AcceleratorConnector
:
**
other_trainier_kwargs
)
->
AcceleratorConnector
:
gpu_ids
=
Trainer
().
_parse_devices
(
gpus
,
auto_select_gpus
,
tpu_cores
)
gpu_ids
=
Trainer
().
_parse_devices
(
gpus
,
auto_select_gpus
,
tpu_cores
)
return
AcceleratorConnector
(
return
AcceleratorConnector
(
...
...
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
View file @
dc58203d
...
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union
...
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
import
py
torch
_lightning
as
pl
import
torch
metrics
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
import
nni
import
nni
...
@@ -19,7 +19,7 @@ from ....serializer import serialize_cls
...
@@ -19,7 +19,7 @@ from ....serializer import serialize_cls
@
serialize_cls
@
serialize_cls
class
_MultiModelSupervisedLearningModule
(
LightningModule
):
class
_MultiModelSupervisedLearningModule
(
LightningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
pl
.
metrics
.
Metric
],
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torch
metrics
.
Metric
],
n_models
:
int
=
0
,
n_models
:
int
=
0
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
...
@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
...
@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
Class for optimizer (not an instance). default: ``Adam``
Class for optimizer (not an instance). default: ``Adam``
"""
"""
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
pl
.
metrics
.
Metric
],
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torch
metrics
.
Metric
],
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
...
@@ -180,7 +180,7 @@ class _RegressionModule(MultiModelSupervisedLearningModule):
...
@@ -180,7 +180,7 @@ class _RegressionModule(MultiModelSupervisedLearningModule):
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
super
().
__init__
(
criterion
,
{
'mse'
:
pl
.
metrics
.
MeanSquaredError
},
super
().
__init__
(
criterion
,
{
'mse'
:
torch
metrics
.
MeanSquaredError
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
...
...
nni/retiarii/evaluator/pytorch/lightning.py
View file @
dc58203d
...
@@ -9,6 +9,7 @@ from typing import Dict, NoReturn, Union, Optional, List, Type
...
@@ -9,6 +9,7 @@ from typing import Dict, NoReturn, Union, Optional, List, Type
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
import
torchmetrics
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
import
nni
import
nni
...
@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
...
@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
### The following are some commonly used Lightning modules ###
class
_SupervisedLearningModule
(
LightningModule
):
class
_SupervisedLearningModule
(
LightningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
pl
.
metrics
.
Metric
],
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torch
metrics
.
Metric
],
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
...
@@ -213,7 +214,7 @@ class _SupervisedLearningModule(LightningModule):
...
@@ -213,7 +214,7 @@ class _SupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
class
_AccuracyWithLogits
(
pl
.
metrics
.
Accuracy
):
class
_AccuracyWithLogits
(
torch
metrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
...
@@ -278,7 +279,7 @@ class _RegressionModule(_SupervisedLearningModule):
...
@@ -278,7 +279,7 @@ class _RegressionModule(_SupervisedLearningModule):
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
export_onnx
:
bool
=
True
):
export_onnx
:
bool
=
True
):
super
().
__init__
(
criterion
,
{
'mse'
:
pl
.
metrics
.
MeanSquaredError
},
super
().
__init__
(
criterion
,
{
'mse'
:
torch
metrics
.
MeanSquaredError
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
export_onnx
=
export_onnx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment