Unverified Commit bcd4748d authored by Mehdi Mirzazadeh's avatar Mehdi Mirzazadeh Committed by GitHub
Browse files

Preparing pipeline for newer versions of pytorch (#726)

* Preparing pipeline for newer versions of pytorch

* updated error message
parent 63f289f2
......@@ -5,7 +5,7 @@
from threading import Condition
from types import TracebackType
from typing import List, Optional, Tuple, Type, Union, cast
from typing import Dict, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor
......@@ -68,6 +68,10 @@ class DistributedPipelineRecord:
self.rank = rank
self.device = device
def __getstate__(self) -> Dict:
# avoid pickling failure.
return {}
def feed(self, chunk: int, input_idx: int, input: Tensor) -> Tensor:
""" This function is called remotely to provide individual tensors of a given chunk."""
if input.device.type == "cpu":
......@@ -167,6 +171,10 @@ class PartitionHandler:
self.num_outputs = num_outputs
(self.in_queue,), (self.out_queue,) = create_workers([self.device])
def __getstate__(self) -> Dict:
# avoid pickling failure.
return {}
def local_parameter_rrefs(self) -> List[rpc.RRef]:
r"""
Create one RRef for each parameter in the given local module, and return a
......
......@@ -20,7 +20,7 @@ Device = Union[torch.device, int, str]
def check_pytorch_version() -> None:
if torch.__version__.split("+")[0].split(".")[:2] < ["1", "9"]:
if list(map(int, torch.__version__.split("+")[0].split(".")[:2])) < [1, 9]:
raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
......
......@@ -31,7 +31,7 @@ else:
DEVICES = [CPU_DEVICES]
pytestmark = pytest.mark.skipif(torch_version() < (1, 10, 0), reason="requires torch version >= 1.10.0")
pytestmark = pytest.mark.skipif(torch_version() < (1, 9, 0), reason="requires torch version >= 1.9.0")
def rpc_worker(rank, world_size, init_file, func, *args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment