Unverified Commit bcd4748d authored by Mehdi Mirzazadeh, committed by GitHub

Preparing pipeline for newer versions of pytorch (#726)

* Preparing pipeline for newer versions of pytorch

* updated error message
parent 63f289f2
@@ -5,7 +5,7 @@
 from threading import Condition
 from types import TracebackType
-from typing import List, Optional, Tuple, Type, Union, cast
+from typing import Dict, List, Optional, Tuple, Type, Union, cast
 
 import torch
 from torch import Tensor
@@ -68,6 +68,10 @@ class DistributedPipelineRecord:
         self.rank = rank
         self.device = device
 
+    def __getstate__(self) -> Dict:
+        # avoid pickling failure.
+        return {}
+
     def feed(self, chunk: int, input_idx: int, input: Tensor) -> Tensor:
         """ This function is called remotely to provide individual tensors of a given chunk."""
         if input.device.type == "cpu":
@@ -167,6 +171,10 @@ class PartitionHandler:
         self.num_outputs = num_outputs
         (self.in_queue,), (self.out_queue,) = create_workers([self.device])
 
+    def __getstate__(self) -> Dict:
+        # avoid pickling failure.
+        return {}
+
     def local_parameter_rrefs(self) -> List[rpc.RRef]:
         r"""
         Create one RRef for each parameter in the given local module, and return a
...
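Both hunks above add the same guard. DistributedPipelineRecord holds a threading.Condition and PartitionHandler holds worker queues, and objects like these wrap _thread.lock instances that pickle cannot serialize. Returning an empty dict from __getstate__ makes the object serialize as if it had no state, so an accidental pickling (for example, when a reference to one of these objects crosses an RPC boundary) no longer raises. A minimal sketch of the failure mode and the fix, using illustrative class names rather than the ones in this patch:

import pickle
from threading import Condition


class Unpicklable:
    """Holds a lock-backed Condition, which pickle cannot serialize."""

    def __init__(self) -> None:
        self.ready = Condition()


class Picklable(Unpicklable):
    """Same state, but serializes as an empty object instead of raising."""

    def __getstate__(self) -> dict:
        # Drop all instance state on pickling, mirroring the patch above.
        return {}


try:
    pickle.dumps(Unpicklable())
except TypeError as exc:
    print(f"pickling failed: {exc}")  # cannot pickle '_thread.lock' object

restored = pickle.loads(pickle.dumps(Picklable()))
print(vars(restored))  # {} -- the instance comes back with no attributes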
@@ -20,7 +20,7 @@ Device = Union[torch.device, int, str]
 def check_pytorch_version() -> None:
-    if torch.__version__.split("+")[0].split(".")[:2] < ["1", "9"]:
+    if list(map(int, torch.__version__.split("+")[0].split(".")[:2])) < [1, 9]:
         raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
...
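The fix above addresses a lexicographic-comparison bug: when the version components are compared as strings, "10" sorts before "9", so PyTorch 1.10 would have failed the 1.9-or-higher check. Casting the components to int restores numeric ordering. A quick demonstration:

# String components compare lexicographically, so "10" < "9":
assert ["1", "10"] < ["1", "9"]  # the old check would reject PyTorch 1.10

# Integer components compare numerically, as the fixed check intends:
assert [1, 10] > [1, 9]

# The fixed parsing applied to a typical version string with a build suffix:
parsed = list(map(int, "1.10.0+cu113".split("+")[0].split(".")[:2]))
assert parsed == [1, 10] and not parsed < [1, 9]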
@@ -31,7 +31,7 @@ else:
     DEVICES = [CPU_DEVICES]
 
-pytestmark = pytest.mark.skipif(torch_version() < (1, 10, 0), reason="requires torch version >= 1.10.0")
+pytestmark = pytest.mark.skipif(torch_version() < (1, 9, 0), reason="requires torch version >= 1.9.0")
 
 def rpc_worker(rank, world_size, init_file, func, *args):
...
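The relaxed test gate compares torch_version() against a tuple of ints, so the helper must strip build metadata and return integer components. The real helper lives elsewhere in this repository; the sketch below is an assumption about its shape, consistent only with how it is called here:

from typing import Tuple

import torch


def torch_version(version: str = torch.__version__) -> Tuple[int, ...]:
    # Hypothetical stand-in for the repository's own helper, shown for
    # illustration. Drop the local build suffix ("+cu113"), then keep the
    # leading digits of each component so dev tags like "0a0" still parse.
    numeric = []
    for part in version.split("+")[0].split(".")[:3]:
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break
            digits += ch
        numeric.append(int(digits or 0))
    return tuple(numeric)


assert torch_version("1.9.0") >= (1, 9, 0)
assert torch_version("1.10.0a0+git1234") >= (1, 9, 0)
assert torch_version("1.8.1") < (1, 9, 0)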