Unverified Commit dc58203d authored by Yuge Zhang, committed by GitHub

Adopt torchmetrics (#4290)

parent 8fc555ad
@@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
 torch == 1.9.0 ; sys_platform == "darwin"
 torchvision == 0.10.0+cpu ; sys_platform != "darwin"
 torchvision == 0.10.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.4.2
+pytorch-lightning >= 1.5
+torchmetrics
 onnx
 peewee
 graphviz
......
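The version bump and the new dependency go together: pytorch-lightning 1.5 dropped the bundled `pl.metrics` namespace (deprecated since 1.3), so every metric must now come from the standalone torchmetrics package. A minimal sketch of the swap, assuming torchmetrics 0.x where `Accuracy()` still takes no required arguments:

```python
import torchmetrics

# Before (pytorch-lightning < 1.5): metric = pl.metrics.Accuracy()
# After: the same metric from the standalone package.
metric = torchmetrics.Accuracy()  # torchmetrics 0.x API; 1.x would require task=...
```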
@@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
 # It will install pytorch-lightning 0.8.x and unit tests won't work.
 # Latest version has conflict with tensorboard and tensorflow 1.x.
 pytorch-lightning
+torchmetrics
 keras == 2.1.6
 onnx
......
@@ -1,9 +1,12 @@
-from typing import Any, Union, Optional, List
-import torch
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Any, List, Optional, Union
+
+import torch
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
 from pytorch_lightning.trainer import Trainer
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.plugins import Plugin
-from pytorch_lightning.plugins.environments import ClusterEnvironment
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from ....serializer import serialize_cls
@@ -69,9 +70,8 @@ class BypassPlugin(TrainingTypePlugin):
         # bypass device placement from pytorch lightning
         pass

-    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
-        self.model_to_device()
-        return self.model
+    def setup(self) -> None:
+        pass

     @property
     def is_global_zero(self) -> bool:
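With pytorch-lightning 1.5, `TrainingTypePlugin.setup()` no longer receives the model or returns it, so `BypassPlugin`'s device-placement bypass reduces to a plain no-op: there is no `model_to_device()` call left to suppress in this hook. A hedged usage sketch (pytorch-lightning 1.5.x; `BypassPlugin`'s constructor arguments, if any, are assumed to be defaulted):

```python
from pytorch_lightning import Trainer

# Handing the plugin to the Trainer leaves device placement entirely to the
# caller, since setup() above deliberately does nothing.
trainer = Trainer(plugins=[BypassPlugin()])
```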
@@ -100,8 +100,9 @@ def get_accelerator_connector(
         deterministic: bool = False,
         precision: int = 32,
         amp_backend: str = 'native',
-        amp_level: str = 'O2',
-        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+        amp_level: Optional[str] = None,
+        plugins: Optional[Union[List[Union[TrainingTypePlugin, ClusterEnvironment, str]],
+                                TrainingTypePlugin, ClusterEnvironment, str]] = None,
         **other_trainier_kwargs) -> AcceleratorConnector:
     gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
     return AcceleratorConnector(
......
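Two API shifts drive the second hunk: pytorch-lightning 1.5 only consults `amp_level` when `amp_backend='apex'` (so a hard-coded `'O2'` default is wrong for the native backend), and the catch-all `Plugin` import is dropped, hence the narrower `TrainingTypePlugin` in the `plugins` annotation. A hypothetical call sketch of the updated helper:

```python
# Native AMP: amp_level stays None and is ignored.
connector = get_accelerator_connector(gpus=1, precision=16, amp_backend='native')

# Apex AMP: amp_level is meaningful and can be set explicitly.
apex_connector = get_accelerator_connector(gpus=1, precision=16,
                                           amp_backend='apex', amp_level='O2')
```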
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union
 import torch.nn as nn
 import torch.optim as optim
-import pytorch_lightning as pl
+import torchmetrics
 from torch.utils.data import DataLoader

 import nni
@@ -19,7 +19,7 @@ from ....serializer import serialize_cls

 @serialize_cls
 class _MultiModelSupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  n_models: int = 0,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
         Class for optimizer (not an instance). default: ``Adam``
     """

-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam):
@@ -180,7 +180,7 @@ class _RegressionModule(MultiModelSupervisedLearningModule):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam):
-        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
......
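A detail worth keeping in mind when reading these signatures: `metrics` maps a display name to a metric class, not an instance; the `{'mse': torchmetrics.MeanSquaredError}` default passes the class and the module instantiates it internally, just as the docstring says the optimizer is "not an instance". A hedged construction sketch (argument values are illustrative only):

```python
import torch.nn as nn
import torchmetrics

# Classes, not instances, for criterion, metrics, and optimizer, mirroring
# the defaults shown in the diff above.
module = MultiModelSupervisedLearningModule(
    criterion=nn.CrossEntropyLoss,
    metrics={'acc': torchmetrics.Accuracy},
    learning_rate=1e-3,
)
```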
@@ -9,6 +9,7 @@ from typing import Dict, NoReturn, Union, Optional, List, Type
 import pytorch_lightning as pl
 import torch.nn as nn
 import torch.optim as optim
+import torchmetrics
 from torch.utils.data import DataLoader

 import nni
@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
 ### The following are some commonly used Lightning modules ###

 class _SupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam,
@@ -213,7 +214,7 @@ class _SupervisedLearningModule(LightningModule):
         return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


-class _AccuracyWithLogits(pl.metrics.Accuracy):
+class _AccuracyWithLogits(torchmetrics.Accuracy):
     def update(self, pred, target):
         return super().update(nn.functional.softmax(pred), target)
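`_AccuracyWithLogits` exists because `torchmetrics.Accuracy` (0.x) validates that floating-point predictions lie in [0, 1]; the network emits raw logits, so the subclass softmaxes them before delegating. A self-contained illustration (with an explicit `dim`, which the original call omits):

```python
import torch
import torch.nn as nn
import torchmetrics

class AccuracyWithLogits(torchmetrics.Accuracy):
    def update(self, pred, target):
        # Convert raw logits to probabilities so the base class accepts them.
        return super().update(nn.functional.softmax(pred, dim=-1), target)

metric = AccuracyWithLogits()
logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])
metric.update(logits, torch.tensor([0, 1]))
print(metric.compute())  # tensor(1.) -- both argmax predictions match targets
```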
@@ -278,7 +279,7 @@ class _RegressionModule(_SupervisedLearningModule):
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam,
                  export_onnx: bool = True):
-        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
                          export_onnx=export_onnx)
......
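For completeness, the replacement metric is stateful in the same way as the old `pl.metrics` one: `update()` accumulates across batches and `compute()` returns the aggregate. A quick standalone check:

```python
import torch
import torchmetrics

mse = torchmetrics.MeanSquaredError()
mse.update(torch.tensor([2.5, 0.0]), torch.tensor([3.0, 0.0]))
print(mse.compute())  # tensor(0.1250): ((2.5 - 3.0)**2 + 0.0**2) / 2
```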