Unverified Commit 5c9797a2 authored by Zhenhua Han, committed by GitHub

Fix interface of CGO's accelerator to support pytorch-lightning 1.4.2 (#4075)

parent cae1e6d4
@@ -8,7 +8,7 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
 torch == 1.9.0 ; sys_platform == "darwin"
 torchvision == 0.10.0+cpu ; sys_platform != "darwin"
 torchvision == 0.10.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.2.8, < 1.4.2
+pytorch-lightning >= 1.4.2
 onnx
 peewee
 graphviz
@@ -4,6 +4,7 @@ import torch
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
+from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.plugins import Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -53,6 +54,13 @@ class BypassPlugin(TrainingTypePlugin):
         """Perform an all_gather on all processes"""
         return tensor

+    def teardown(self):
+        """
+        This method is called to tear down the training process.
+        It is the right place to release memory and free other resources.
+        """
+        pass
+
     @property
     def root_device(self) -> torch.device:
         return torch.device(self.device)
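
The new no-op teardown is presumably needed because pytorch-lightning 1.4 declares TrainingTypePlugin.teardown as abstract, so BypassPlugin could no longer be instantiated without an override. A minimal sketch of that mechanism, using stand-in classes rather than the commit's code:

from abc import ABC, abstractmethod

class PluginBase(ABC):            # stand-in for TrainingTypePlugin
    @abstractmethod
    def teardown(self):
        ...

class Bypass(PluginBase):         # stand-in for BypassPlugin
    def teardown(self):
        pass                      # nothing to release for a bypass plugin

Bypass().teardown()               # instantiates and runs fine
# Without the override, instantiation would raise:
# TypeError: Can't instantiate abstract class Bypass with abstract method teardown
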
@@ -78,10 +86,13 @@ class BypassPlugin(TrainingTypePlugin):
 def get_accelerator_connector(
         num_processes: int = 1,
+        devices: Optional[Union[List[int], str, int]] = None,
         tpu_cores: Optional[Union[List[int], str, int]] = None,
+        ipus: Optional[int] = None,
         distributed_backend: Optional[str] = None,
-        auto_select_gpus: bool = False,
+        accelerator: Optional[Union[str, Accelerator]] = None,
         gpus: Optional[Union[List[int], str, int]] = None,
+        auto_select_gpus: bool = False,
         num_nodes: int = 1,
         sync_batchnorm: bool = False,
         benchmark: bool = False,
@@ -90,17 +101,35 @@ def get_accelerator_connector(
         precision: int = 32,
         amp_backend: str = 'native',
         amp_level: str = 'O2',
-        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None):
+        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+        **other_trainier_kwargs) -> AcceleratorConnector:
+
+    gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
     return AcceleratorConnector(
-        num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark,
-        replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
+        num_processes,
+        devices,
+        tpu_cores,
+        ipus,
+        distributed_backend,
+        accelerator,
+        gpus,
+        gpu_ids,
+        num_nodes,
+        sync_batchnorm,
+        benchmark,
+        replace_sampler_ddp,
+        deterministic,
+        precision,
+        amp_backend,
+        amp_level,
+        plugins,
     )

 @serialize_cls
 class BypassAccelerator(Accelerator):
-    def __init__(self, precision_plugin=None, device="cpu"):
+    def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs):
         if precision_plugin is None:
-            precision_plugin = get_accelerator_connector().precision_plugin
+            precision_plugin = get_accelerator_connector(**trainer_kwargs).select_precision_plugin()

         # pylint: disable=abstract-class-instantiated
         super().__init__(precision_plugin=precision_plugin, training_type_plugin=BypassPlugin(device))
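
For context: the constructor change means any extra Trainer-style keyword arguments now flow through to pytorch-lightning 1.4.2's AcceleratorConnector, which expects the expanded argument list (devices, ipus, accelerator, gpu_ids) and exposes precision-plugin selection through the select_precision_plugin() method rather than a property. A hypothetical usage sketch; the import path is an assumption, not shown in this diff:

# Hypothetical usage; the module path below is assumed, not part of the diff.
from nni.retiarii.evaluator.pytorch.cgo.accelerator import BypassAccelerator

# Trainer kwargs such as precision are forwarded to get_accelerator_connector(),
# which parses device ids the same way pl.Trainer does and then selects the
# matching precision plugin for the bypass accelerator.
accelerator = BypassAccelerator(device='cpu', precision=32)
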
@@ -26,6 +26,6 @@ class Trainer(pl.Trainer):
         if use_cgo:
             if "accelerator" in trainer_kwargs:
                 raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
-            trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu')
+            trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu', **trainer_kwargs)
         super().__init__(**trainer_kwargs)
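
Taken together, enabling cross-graph optimization might look like the sketch below; the module path and keyword arguments are illustrative assumptions, not taken from this diff:

# Hypothetical usage; module path assumed from context, not part of the diff.
from nni.retiarii.evaluator.pytorch.cgo.trainer import Trainer

# With use_cgo=True, the Trainer injects a BypassAccelerator and, after this
# fix, also forwards its remaining kwargs to it, so the accelerator's
# connector sees the same configuration (e.g. precision) as the Trainer.
trainer = Trainer(use_cgo=True, max_epochs=1, precision=32)
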