Unverified Commit 5c9797a2 authored by Zhenhua Han, committed by GitHub

Fix interface of CGO's accelerator to support pytorch-lightning 1.4.2 (#4075)

parent cae1e6d4
@@ -8,7 +8,7 @@
 torch == 1.9.0+cpu ; sys_platform != "darwin"
 torch == 1.9.0 ; sys_platform == "darwin"
 torchvision == 0.10.0+cpu ; sys_platform != "darwin"
 torchvision == 0.10.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.2.8, < 1.4.2
+pytorch-lightning >= 1.4.2
 onnx
 peewee
 graphviz
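A quick environment check, not part of the commit, for confirming the relaxed pin before running CGO; it assumes the third-party packaging distribution is available:

import pytorch_lightning as pl
from packaging.version import Version

# CGO now requires pytorch-lightning 1.4.2 or newer.
assert Version(pl.__version__) >= Version("1.4.2"), pl.__version__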
@@ -4,6 +4,7 @@
 import torch
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.plugins import Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -53,6 +54,13 @@ class BypassPlugin(TrainingTypePlugin):
         """Perform an all_gather on all processes."""
         return tensor
 
+    def teardown(self):
+        """
+        This method is called to tear down the training process.
+        It is the right place to release memory and free other resources.
+        """
+        pass
+
     @property
     def root_device(self) -> torch.device:
         return torch.device(self.device)
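The no-op teardown above appears to be needed because TrainingTypePlugin.teardown became an abstract method in pytorch-lightning 1.4.x, so every concrete plugin must define it even when there is nothing to release. For illustration only, a hypothetical non-trivial override (not from the commit; it assumes the plugin holds a model reachable via lightning's lightning_module property) might look like:

    def teardown(self) -> None:
        # Hypothetical: move the model off the GPU and release cached CUDA
        # memory, similar to what lightning's single-device plugin does.
        if self.root_device.type == "cuda":
            self.lightning_module.cpu()
            torch.cuda.empty_cache()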
@@ -78,10 +86,13 @@
 
 def get_accelerator_connector(
         num_processes: int = 1,
+        devices: Optional[Union[List[int], str, int]] = None,
         tpu_cores: Optional[Union[List[int], str, int]] = None,
+        ipus: Optional[int] = None,
         distributed_backend: Optional[str] = None,
-        auto_select_gpus: bool = False,
+        accelerator: Optional[Union[str, Accelerator]] = None,
         gpus: Optional[Union[List[int], str, int]] = None,
+        auto_select_gpus: bool = False,
         num_nodes: int = 1,
         sync_batchnorm: bool = False,
         benchmark: bool = False,
@@ -90,17 +101,35 @@
         precision: int = 32,
         amp_backend: str = 'native',
         amp_level: str = 'O2',
-        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None):
+        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+        **other_trainer_kwargs) -> AcceleratorConnector:
+    gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
     return AcceleratorConnector(
-        num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark,
-        replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
+        num_processes,
+        devices,
+        tpu_cores,
+        ipus,
+        distributed_backend,
+        accelerator,
+        gpus,
+        gpu_ids,
+        num_nodes,
+        sync_batchnorm,
+        benchmark,
+        replace_sampler_ddp,
+        deterministic,
+        precision,
+        amp_backend,
+        amp_level,
+        plugins,
     )
 
 @serialize_cls
 class BypassAccelerator(Accelerator):
-    def __init__(self, precision_plugin=None, device="cpu"):
+    def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs):
         if precision_plugin is None:
-            precision_plugin = get_accelerator_connector().precision_plugin
+            precision_plugin = get_accelerator_connector(**trainer_kwargs).select_precision_plugin()
         # pylint: disable=abstract-class-instantiated
         super().__init__(precision_plugin=precision_plugin, training_type_plugin=BypassPlugin(device))
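Two things are easy to miss here. First, the positional argument list passed to AcceleratorConnector has to line up exactly with its __init__ order in pytorch-lightning 1.4.2, which is why devices, ipus, accelerator and the parsed gpu_ids are threaded through. Second, **other_trainer_kwargs silently absorbs any Trainer option the helper does not recognise. A brief usage sketch with hypothetical values:

# Unrecognised Trainer options such as max_epochs are swallowed by
# **other_trainer_kwargs; recognised ones (precision, amp_backend, ...)
# configure the connector and, through it, the precision plugin.
connector = get_accelerator_connector(precision=32, max_epochs=10)
accelerator = BypassAccelerator(device="cpu", precision=32, max_epochs=10)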
@@ -26,6 +26,6 @@ class Trainer(pl.Trainer):
         if use_cgo:
             if "accelerator" in trainer_kwargs:
                 raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
-            trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu')
+            trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu', **trainer_kwargs)
         super().__init__(**trainer_kwargs)
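A usage sketch for the CGO Trainer wrapper, assuming its __init__ accepts use_cgo plus ordinary pl.Trainer keyword arguments as the hunk suggests (argument values are hypothetical):

# With use_cgo=True, the remaining kwargs now reach BypassAccelerator
# (and through it the precision plugin) in addition to pl.Trainer itself.
trainer = Trainer(use_cgo=True, max_epochs=1, limit_train_batches=0.1)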